diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 3951541c..8c9964b8 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -389,6 +389,7 @@ class Preprocessors(object): # multi-modal preprocessor ofa_tasks_preprocessor = 'ofa-tasks-preprocessor' + clip_preprocessor = 'clip-preprocessor' mplug_tasks_preprocessor = 'mplug-tasks-preprocessor' # science preprocessor @@ -428,6 +429,8 @@ class Metrics(object): image_inpainting_metric = 'image-inpainting-metric' # metric for ocr NED = 'ned' + # metric for cross-modal retrieval + inbatch_recall = 'inbatch_recall' # metric for referring-video-object-segmentation task referring_video_object_segmentation_metric = 'referring-video-object-segmentation-metric' @@ -474,6 +477,9 @@ class Hooks(object): # Compression SparsityHook = 'SparsityHook' + # CLIP logit_scale clamp + ClipClampLogitScaleHook = 'ClipClampLogitScaleHook' + class LR_Schedulers(object): """learning rate scheduler is defined here diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py index 2b61c1ae..b9e402c5 100644 --- a/modelscope/metrics/builder.py +++ b/modelscope/metrics/builder.py @@ -24,6 +24,7 @@ class MetricKeys(object): ROUGE_1 = 'rouge-1' ROUGE_L = 'rouge-l' NED = 'ned' # ocr metric + BatchAcc = 'inbatch_t2i_recall_at_1' task_default_metrics = { diff --git a/modelscope/metrics/inbatch_recall_metric.py b/modelscope/metrics/inbatch_recall_metric.py new file mode 100644 index 00000000..d098a883 --- /dev/null +++ b/modelscope/metrics/inbatch_recall_metric.py @@ -0,0 +1,55 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Dict + +import numpy as np +import torch + +from modelscope.metainfo import Metrics +from modelscope.outputs import OutputKeys +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.inbatch_recall) +class InbatchRecallMetric(Metric): + """The metric computation class for in-batch retrieval classes. + + This metric class calculates in-batch image recall@1 for each input batch. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.inbatch_t2i_hitcnts = [] + self.batch_sizes = [] + + def add(self, outputs: Dict, inputs: Dict): + image_features = outputs[OutputKeys.IMG_EMBEDDING] + text_features = outputs[OutputKeys.TEXT_EMBEDDING] + + assert type(image_features) == torch.Tensor and type( + text_features) == torch.Tensor + + with torch.no_grad(): + logits_per_image = image_features @ text_features.t() + logits_per_text = logits_per_image.t() + batch_size = logits_per_image.shape[0] + + ground_truth = torch.arange(batch_size).long() + ground_truth = ground_truth.to(image_features.device) + + inbatch_t2i_hitcnt = (logits_per_text.argmax(-1) == ground_truth + ).sum().float().item() + + self.inbatch_t2i_hitcnts.append(inbatch_t2i_hitcnt) + self.batch_sizes.append(batch_size) + + def evaluate(self): + assert len(self.inbatch_t2i_hitcnts) == len( + self.batch_sizes) and len(self.batch_sizes) > 0 + return { + MetricKeys.BatchAcc: + sum(self.inbatch_t2i_hitcnts) / sum(self.batch_sizes) + } diff --git a/modelscope/models/multi_modal/clip/model.py b/modelscope/models/multi_modal/clip/model.py index 92d9e11a..b1c84292 100644 --- a/modelscope/models/multi_modal/clip/model.py +++ b/modelscope/models/multi_modal/clip/model.py @@ -15,15 +15,13 @@ import os from collections import OrderedDict -from typing import Any, Dict, Iterable, List, Tuple, Union +from typing import Any, Dict, Tuple, Union import json import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from PIL import Image -from torchvision.transforms import Compose, Normalize, Resize, ToTensor from modelscope.metainfo import Models from modelscope.models import TorchModel @@ -506,21 +504,6 @@ def convert_weights(model: nn.Module): model.apply(_convert_weights_to_fp16) -def _convert_to_rgb(image): - return image.convert('RGB') - - -def image_transform(image_size=224): - transform = Compose([ - _convert_to_rgb, - Resize((image_size, image_size)), - ToTensor(), - Normalize((0.48145466, 0.4578275, 0.40821073), - (0.26862954, 0.26130258, 0.27577711)), - ]) - return transform - - @MODELS.register_module(Tasks.multi_modal_embedding, module_name=Models.clip) class CLIPForMultiModalEmbedding(TorchModel): @@ -540,72 +523,40 @@ class CLIPForMultiModalEmbedding(TorchModel): with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft: - model_info = json.load(fv) + self.model_info = json.load(fv) for k, v in json.load(ft).items(): - model_info[k] = v - - # image preprocess - self.img_preprocess = image_transform(model_info['image_resolution']) + self.model_info[k] = v - # text tokenizer vocab_file = f'{model_dir}/{ModelFile.VOCAB_FILE}' self.tokenizer = FullTokenizer(vocab_file=vocab_file) # initialize the model - self.clip_model = CLIP(**model_info, tokenizer=self.tokenizer) + self.clip_model = CLIP(**self.model_info, tokenizer=self.tokenizer) convert_weights(self.clip_model) # restore the pretrained weight checkpoint = torch.load( f'{model_dir}/{ModelFile.TORCH_MODEL_BIN_FILE}', 'cpu') - sd = checkpoint['state_dict'] + sd = checkpoint[ + 'state_dict'] if 'state_dict' in checkpoint else checkpoint if next(iter(sd.items()))[0].startswith('module'): sd = {k[len('module.'):]: v for k, v in sd.items()} + # support the finetuned model + if next(iter(sd.items()))[0].startswith('clip_model'): + sd = {k[len('clip_model.'):]: v for k, v in sd.items()} self.clip_model.load_state_dict(sd) self.clip_model.eval() # place the model - self.device = 
'cuda' if torch.cuda.is_available() else 'cpu' - if self.device == 'cuda': + self.device = 'cuda:{}'.format(int(os.environ.get( + 'LOCAL_RANK', 0))) if torch.cuda.is_available() else 'cpu' + if torch.cuda.is_available(): self.clip_model.to(self.device) - logger.info('Use GPU for inference') + logger.info('Use GPU {} for finetuning & inference'.format( + int(os.environ.get('LOCAL_RANK', 0)))) else: self.clip_model.float() - logger.info('Use CPU for inference') - - def tokenize(self, - texts: Union[str, List[str]], - context_length: int = 52) -> torch.LongTensor: - """ - Returns the tokenized representation of given input string(s) - Parameters - ---------- - texts : Union[str, List[str]] - An input string or a list of input strings to tokenize - context_length : int - The context length to use; all baseline models use 24 as the context length - Returns - ------- - A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] - """ - if isinstance(texts, str): - texts = [texts] - - all_tokens = [] - for text in texts: - all_tokens.append( - [self.tokenizer.vocab['[CLS]']] - + self.tokenizer.convert_tokens_to_ids( - self.tokenizer.tokenize(text))[:context_length - 2] - + [self.tokenizer.vocab['[SEP]']]) - - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - - for i, tokens in enumerate(all_tokens): - assert len(tokens) <= context_length - result[i, :len(tokens)] = torch.tensor(tokens) - - return result + logger.info('Use CPU for finetuning & inference') def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: from modelscope.outputs import OutputKeys @@ -613,75 +564,36 @@ class CLIPForMultiModalEmbedding(TorchModel): OutputKeys.IMG_EMBEDDING: None, OutputKeys.TEXT_EMBEDDING: None } - if 'img' in input and input['img'] is not None: - image_input = input['img'] - - # single image input - if isinstance(image_input, Image.Image): - image_tensor = self.img_preprocess(image_input).unsqueeze(0) - # multi images input - elif isinstance(image_input, list): - if all([isinstance(elem, Image.Image) - for elem in image_input]): - image_tensor = torch.stack( - [self.img_preprocess(elem) for elem in image_input], - dim=0) - else: - unsupported_elem_type = [ - type(elem) for elem in image_input - if not isinstance(elem, Image.Image) - ][0] - raise TypeError( - f'img should be PIL.Image or List[PIL.Image], \ - but got a List containing one {unsupported_elem_type}' - ) - # others - else: - raise TypeError( - f'img should be PIL.Image or List[PIL.Image], but got {type(image_input)}' - ) - - image_tensor = image_tensor.to(self.device) - - with torch.no_grad(): + mode = input.get('mode', ModeKeys.INFERENCE) + + # encode the image + if 'img' in input and isinstance(input['img'], torch.Tensor): + image_tensor = input['img'].to(self.device) + if image_tensor.dim() == 5 and image_tensor.shape[1] == 1: + image_tensor = image_tensor.squeeze(1) + + with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN): image_features = self.clip_model.encode_image(image_tensor) image_features /= image_features.norm( dim=-1, keepdim=True) # l2-normalize output[OutputKeys.IMG_EMBEDDING] = image_features - if 'text' in input and input['text'] is not None: - text_input = input['text'] - - # single text input - if isinstance(text_input, str): - text_tensor = self.tokenize(text_input) - # multi texts input - elif isinstance(text_input, list): - if all([isinstance(elem, str) for elem in text_input]): - text_tensor = self.tokenize(text_input) - else: - 
unsupported_elem_type = [ - type(elem) for elem in text_input - if not isinstance(elem, str) - ][0] - raise TypeError( - f'text should be str or List[str], but got a List containing one {unsupported_elem_type}' - ) - # others - else: - raise TypeError( - f'text should be str or List[str], but got {type(text_input)}' - ) - - text_tensor = text_tensor.to(self.device) - - with torch.no_grad(): + if 'text' in input and isinstance(input['text'], torch.Tensor): + text_tensor = input['text'].to(self.device) + if text_tensor.dim() == 3 and text_tensor.shape[1] == 1: + text_tensor = text_tensor.squeeze(1) + + with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN): text_features = self.clip_model.encode_text(text_tensor) text_features /= text_features.norm( dim=-1, keepdim=True) # l2-normalize output[OutputKeys.TEXT_EMBEDDING] = text_features + if mode == ModeKeys.TRAIN: + output['logit_scale'] = (self.clip_model.logit_scale + * 1.0).exp().mean() + return output def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: diff --git a/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py b/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py index d3f15c23..18ee1dbf 100644 --- a/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py +++ b/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py @@ -1,10 +1,12 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Any, Dict +from typing import Any, Dict, Optional, Union from modelscope.metainfo import Pipelines +from modelscope.models.multi_modal.clip.model import CLIPForMultiModalEmbedding from modelscope.pipelines.base import Input, Model, Pipeline from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors.multi_modal import CLIPPreprocessor, Preprocessor from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger @@ -17,7 +19,10 @@ logger = get_logger() Tasks.multi_modal_embedding, module_name=Pipelines.multi_modal_embedding) class MultiModalEmbeddingPipeline(Pipeline): - def __init__(self, model: str, device: str = 'gpu'): + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): """ use `model` and `preprocessor` to create a kws pipeline for prediction Args: @@ -29,14 +34,17 @@ class MultiModalEmbeddingPipeline(Pipeline): pipe_model = model else: raise NotImplementedError('model must be a single str') + pipe_model.eval() + if preprocessor is None: + if isinstance(pipe_model, CLIPForMultiModalEmbedding): + preprocessor = CLIPPreprocessor(pipe_model.model_dir) + else: + raise NotImplementedError - super().__init__(model=pipe_model) - - def preprocess(self, input: Input) -> Dict[str, Any]: - return input + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: - return self.model(input) + return self.model(self.preprocess(input)) def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return inputs diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py index 557b469a..17dffb48 100644 --- a/modelscope/preprocessors/multi_modal.py +++ b/modelscope/preprocessors/multi_modal.py @@ -3,8 +3,11 @@ import os.path as osp from io import BytesIO from typing import Any, Dict, List, Tuple, Union +import json import torch from PIL import Image +from timm.data import create_transform +from torchvision.transforms import Compose, Normalize, Resize, ToTensor from 
modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import Preprocessors @@ -107,6 +110,180 @@ class OfaPreprocessor(Preprocessor): eos_idx=self.tokenizer.eos_token_id) +def _convert_to_rgb(image): + return image.convert('RGB') + + +@PREPROCESSORS.register_module( + Fields.multi_modal, module_name=Preprocessors.clip_preprocessor) +class CLIPPreprocessor(Preprocessor): + + def __init__(self, + model_dir: str, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + mode: preprocessor mode (model mode) + """ + super().__init__(*args, **kwargs) + model_dir = model_dir if osp.exists(model_dir) else snapshot_download( + model_dir) + self.mode = mode + # text tokenizer + from modelscope.models.multi_modal.clip.bert_tokenizer import FullTokenizer + if 'tokenizer' in kwargs and isinstance(kwargs['tokenizer'], + FullTokenizer): + self.tokenizer = kwargs['tokenizer'] + else: + vocab_file = f'{model_dir}/{ModelFile.VOCAB_FILE}' + self.tokenizer = FullTokenizer(vocab_file=vocab_file) + # image preprocessor + if 'resolution' in kwargs and isinstance(kwargs['resolution'], int): + self.image_resolution = kwargs['resolution'] + else: + self.image_resolution = json.load( + open('{}/vision_model_config.json'.format( + model_dir)))['image_resolution'] + self.img_preprocess = self._build_image_transform() + # key mapping + # specify the input keys, compatible with training and inference whose key names may be different + self.input_keys = {'img': 'img', 'text': 'text'} + + def _build_image_transform(self): + + if self.mode == ModeKeys.TRAIN: + transform = create_transform( + input_size=self.image_resolution, + scale=(0.9, 1.0), + is_training=True, + color_jitter=None, + auto_augment='original', + interpolation='bicubic', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ) + transform = Compose(transform.transforms[:-3] + [_convert_to_rgb] + + transform.transforms[-3:]) + else: + transform = Compose([ + Resize((self.image_resolution, self.image_resolution), + interpolation=Image.BICUBIC), + _convert_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), + ]) + return transform + + def tokenize(self, + texts: Union[str, List[str]], + context_length: int = 52) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all baseline models use 24 as the context length + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + all_tokens = [] + for text in texts: + all_tokens.append( + [self.tokenizer.vocab['[CLS]']] + + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize(text))[:context_length - 2] + + [self.tokenizer.vocab['[SEP]']]) + + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + assert len(tokens) <= context_length + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + def set_input_img_key(self, new_key: str): + self.input_keys['img'] = new_key + + def set_input_text_key(self, new_key: str): + self.input_keys['text'] = new_key + + def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args, + 
**kwargs) -> Dict[str, Any]: + output = {} + # preprocess the image input + input_img_key = self.input_keys['img'] + if input_img_key in input and input[input_img_key] is not None: + image_input = input[input_img_key] + + # single image input + if isinstance(image_input, Image.Image): + image_tensor = self.img_preprocess(image_input).unsqueeze(0) + # multi images input + elif isinstance(image_input, list): + if all([isinstance(elem, Image.Image) + for elem in image_input]): + image_tensor = torch.stack( + [self.img_preprocess(elem) + for elem in image_input], # noqa + dim=0) # noqa + else: + unsupported_elem_type = [ + type(elem) for elem in image_input + if not isinstance(elem, Image.Image) + ][0] + raise TypeError( + f'img should be PIL.Image or List[PIL.Image], \ + but got a List containing one {unsupported_elem_type}' + ) + # others + else: + raise TypeError( + f'img should be PIL.Image or List[PIL.Image], but got {type(image_input)}' + ) + output['img'] = image_tensor + + # preprocess the text input + input_text_key = self.input_keys['text'] + if input_text_key in input and input[input_text_key] is not None: + text_input = input[input_text_key] + + # single text input + if isinstance(text_input, str): + text_tensor = self.tokenize(text_input) + # multi texts input + elif isinstance(text_input, list): + if all([isinstance(elem, str) for elem in text_input]): + text_tensor = self.tokenize(text_input) + else: + unsupported_elem_type = [ + type(elem) for elem in text_input + if not isinstance(elem, str) + ][0] + raise TypeError( + f'text should be str or List[str], but got a List containing one {unsupported_elem_type}' + ) + # others + else: + raise TypeError( + f'text should be str or List[str], but got {type(text_input)}' + ) + output['text'] = text_tensor + + return output + + @PREPROCESSORS.register_module( Fields.multi_modal, module_name=Preprocessors.mplug_tasks_preprocessor) class MPlugPreprocessor(Preprocessor): diff --git a/modelscope/trainers/hooks/clip_clamp_logit_scale_hook.py b/modelscope/trainers/hooks/clip_clamp_logit_scale_hook.py new file mode 100644 index 00000000..ce98e6c9 --- /dev/null +++ b/modelscope/trainers/hooks/clip_clamp_logit_scale_hook.py @@ -0,0 +1,18 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch + +from modelscope.metainfo import Hooks +from modelscope.trainers.multi_modal.clip.clip_trainer import CLIPTrainer +from .builder import HOOKS +from .hook import Hook + + +@HOOKS.register_module(module_name=Hooks.ClipClampLogitScaleHook) +class ClipClampLogitScaleHook(Hook): + """ClipClampLogitScaleHook hook which performs clamp on CLIP logit scale parameter after update""" + + def after_train_iter(self, trainer: CLIPTrainer): + """Called after every training iter to evaluate the results.""" + unwrapped_model = getattr(trainer.model, 'module', trainer.model) + logit_scale = unwrapped_model.clip_model.logit_scale + logit_scale.data = torch.clamp(logit_scale.data, 0, 4.6052) diff --git a/modelscope/trainers/multi_modal/clip/clip_trainer.py b/modelscope/trainers/multi_modal/clip/clip_trainer.py index cbe83417..40c524ac 100644 --- a/modelscope/trainers/multi_modal/clip/clip_trainer.py +++ b/modelscope/trainers/multi_modal/clip/clip_trainer.py @@ -1,169 +1,206 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
+import math import os -from typing import Dict, Optional +from typing import Callable, Dict, Optional, Tuple, Union import torch -import torch.distributed as dist -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler +from torch import distributed as dist +from torch import nn +from torch.utils.data import Dataset from modelscope.metainfo import Trainers -from modelscope.models.base import Model -from modelscope.trainers.base import BaseTrainer +from modelscope.models.base import Model, TorchModel +from modelscope.models.multi_modal.clip.model import convert_models_to_fp32 +from modelscope.msdatasets.ms_dataset import MsDataset +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.multi_modal import CLIPPreprocessor +from modelscope.trainers import EpochBasedTrainer from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.optimizer.builder import build_optimizer from modelscope.utils.config import Config -from modelscope.utils.constant import ModeKeys -from modelscope.utils.logger import get_logger -from .clip_trainer_utils import ImageWithCaptionDataset, get_optimizer +from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys, + ModeKeys) +from .clip_trainer_utils import get_loss, get_optimizer_params, get_schedule -logger = get_logger() + +def exclude(n): + return 'bn' in n or 'ln' in n or 'bias' in n or 'logit_scale' in n + + +def include(n): + return not exclude(n) @TRAINERS.register_module(module_name=Trainers.clip_multi_modal_embedding) -class CLIPTrainer(BaseTrainer): - - def __init__(self, cfg_file: str, model: str, device_id: int, *args, - **kwargs): - super().__init__(cfg_file) - - self.cfg = Config.from_file(cfg_file) - self.model = Model.from_pretrained(model) - self.device_id = device_id - self.total_epoch = self.cfg.train.epoch - self.train_batch_size = self.cfg.train.batch_size - self.val_batch_size = self.cfg.evaluation.batch_size - self.ckpt_dir = self.cfg.train.ckpt_dir - - self.train_dataset = ImageWithCaptionDataset( - json_file='{}/{}'.format(self.cfg.dataset.root_dir, - self.cfg.dataset.train_set), - img_dir=self.cfg.dataset.root_dir, - phase=ModeKeys.TRAIN) - self.val_dataset = ImageWithCaptionDataset( - json_file='{}/{}'.format(self.cfg.dataset.root_dir, - self.cfg.dataset.val_set), - img_dir=self.cfg.dataset.root_dir, - phase=ModeKeys.EVAL) - - def train(self, *args, **kwargs): - assert dist.is_initialized() - - self.model.clip_model.train() - self.model.clip_model.to(self.device_id) - ddp_model = torch.nn.parallel.DistributedDataParallel( - self.model.clip_model, device_ids=[ - self.device_id, - ]) - - optimizer = get_optimizer(ddp_model) - - for epoch in range(self.total_epoch): - train_sampler = DistributedSampler( - dataset=self.train_dataset, shuffle=True) - train_sampler.set_epoch(epoch) - - train_params = { - 'pin_memory': True, - 'collate_fn': None, - 'batch_size': self.train_batch_size, - 'shuffle': False, - 'drop_last': True, - 'sampler': train_sampler, - 'num_workers': 8 +class CLIPTrainer(EpochBasedTrainer): + + def __init__( + self, + model: Optional[Union[TorchModel, nn.Module, str]] = None, + cfg_file: Optional[str] = None, + arg_parse_fn: Optional[Callable] = None, + data_collator: Optional[Union[Callable, Dict[str, + Callable]]] = None, + train_dataset: Optional[Union[MsDataset, Dataset]] = None, + eval_dataset: Optional[Union[MsDataset, Dataset]] = None, + preprocessor: Optional[Union[Preprocessor, + Dict[str, Preprocessor]]] = None, 
+ optimizers: Tuple[torch.optim.Optimizer, + torch.optim.lr_scheduler._LRScheduler] = (None, + None), + model_revision: Optional[str] = DEFAULT_MODEL_REVISION, + seed: int = 42, + **kwargs): + model = Model.from_pretrained(model, revision=model_revision) + # for training & eval, we convert the model from FP16 back to FP32 + # to compatible with modelscope amp training + convert_models_to_fp32(model) + cfg = Config.from_file(cfg_file) + if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0: + work_dir = cfg.train.work_dir + else: + work_dir = kwargs['work_dir'] + + # fetch the model name of CLIP model (base, large or large-336) + model_name = cfg.pretrained_model.model_name + + # world size + world_size = int(os.environ.get('WORLD_SIZE', 1)) + + # train step, optimizer and lr_scheduler + epoch_steps = math.ceil( + len(train_dataset) / # noqa + (cfg.train.dataloader.batch_size_per_gpu * world_size)) # noqa + cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs + + if optimizers[0] is None: + named_parameters = list(model.named_parameters()) + gain_or_bias_params = [ + p for n, p in named_parameters + if exclude(n) and p.requires_grad + ] + rest_params = [ + p for n, p in named_parameters + if include(n) and p.requires_grad + ] + optimizer_hparams = get_optimizer_params( + model_name, cfg) # lr, wd, beta1, beta2, eps + optimizer_args = { + 'params': [ + { + 'params': gain_or_bias_params, + 'weight_decay': 0. + }, + { + 'params': rest_params, + 'weight_decay': optimizer_hparams['weight_decay'] + }, + ], + 'lr': + optimizer_hparams['lr'], + 'betas': + (optimizer_hparams['beta1'], optimizer_hparams['beta2']), + 'eps': + optimizer_hparams['eps'], + } + optimizer = build_optimizer( + model, cfg=cfg.train.optimizer, default_args=optimizer_args) + else: + optimizer = optimizers[0] + + if optimizers[1] is None: + lr_scheduler = get_schedule(optimizer, cfg.train.lr_scheduler) + else: + lr_scheduler = optimizers[1] + optimizers = (optimizer, lr_scheduler) + + # loss module + loss_img = nn.CrossEntropyLoss() + loss_txt = nn.CrossEntropyLoss() + self.loss_img = loss_img.cuda(int(os.environ.get('LOCAL_RANK', 0))) + self.loss_txt = loss_txt.cuda(int(os.environ.get('LOCAL_RANK', 0))) + self.loss_cfg = cfg.train.loss_cfg + + # launcher and use_fp16 + if 'launcher' not in kwargs and cfg.train.get('launcher', None): + kwargs['launcher'] = cfg.train.launcher + if 'use_fp16' not in kwargs and cfg.train.get('use_fp16', False): + kwargs['use_fp16'] = cfg.train.use_fp16 + + # preprocessor + if preprocessor is None: + preprocessor = { + ConfigKeys.train: + CLIPPreprocessor( + model_dir=work_dir, + mode=ModeKeys.TRAIN, + tokenizer=model.tokenizer, + resolution=model.model_info['image_resolution']), + ConfigKeys.val: + CLIPPreprocessor( + model_dir=work_dir, + mode=ModeKeys.EVAL, + tokenizer=model.tokenizer, + resolution=model.model_info['image_resolution']), } - train_loader = DataLoader(self.train_dataset, **train_params) - - for batch_idx, (img_tensor, text_str_list, - img_id_list) in enumerate(train_loader): - text_info_list = [ - self.model.tokenize_text(tmp) for tmp in text_str_list - ] - text_ids_tensor = torch.cat([tmp[0] for tmp in text_info_list], - dim=0) - text_masks_tensor = torch.cat( - [tmp[1] for tmp in text_info_list], dim=0) - - img_tensor = img_tensor.to(self.device_id, non_blocking=True) - img_id_list = img_id_list.to(self.device_id, non_blocking=True) - text_ids_tensor = text_ids_tensor.to( - self.device_id, non_blocking=True) - text_masks_tensor = 
text_masks_tensor.to( - self.device_id, non_blocking=True) - - loss = ddp_model((img_tensor, text_ids_tensor, - text_masks_tensor, img_id_list), - ModeKeys.TRAIN) - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - if batch_idx % 10 == 0: - logger.info( - 'epoch: {}, train batch {}/{}, loss={:.5f}, logit_scale={:.5f}' - .format(epoch, batch_idx, len(train_loader), - loss.item(), - ddp_model.module.logit_scale.exp().item())) - if dist.get_rank() == 0: - os.makedirs(self.ckpt_dir, exist_ok=True) - torch.save(ddp_model.module.state_dict(), - '{}/epoch{}.pth'.format(self.ckpt_dir, epoch)) - - def evaluate(self, - checkpoint_path: Optional[str] = None, - *args, - **kwargs) -> Dict[str, float]: - if checkpoint_path is not None: - checkpoint_params = torch.load(checkpoint_path, 'cpu') - self.model.clip_model.load_state_dict(checkpoint_params) - self.model.clip_model.eval() - self.model.clip_model.to(self.device_id) - - val_params = { - 'collate_fn': None, - 'batch_size': self.val_batch_size, - 'shuffle': False, - 'drop_last': False, - 'num_workers': 8 - } - val_loader = DataLoader(self.val_dataset, **val_params) - - tp_cnt_per_batch = [] - processed_cnt = 0 - with torch.no_grad(): - for batch_idx, (img_tensor, text_str_list, - img_id_list) in enumerate(val_loader): - text_info_list = [ - self.model.tokenize_text(tmp) for tmp in text_str_list - ] - text_ids_tensor = torch.cat([tmp[0] for tmp in text_info_list], - dim=0) - text_masks_tensor = torch.cat( - [tmp[1] for tmp in text_info_list], dim=0) - - img_tensor = img_tensor.to(self.device_id, non_blocking=True) - img_id_list = img_id_list.to(self.device_id, non_blocking=True) - text_ids_tensor = text_ids_tensor.to( - self.device_id, non_blocking=True) - text_masks_tensor = text_masks_tensor.to( - self.device_id, non_blocking=True) - - img_feat = self.model.clip_model(img_tensor, input_type='img') - text_feat = self.model.clip_model( - (text_ids_tensor, text_masks_tensor), input_type='text') - - sim_mat = text_feat @ img_feat.t() - text_cnt, img_cnt = sim_mat.shape - top1_scores, match_ids = torch.max(sim_mat, dim=1) - - match_ids = match_ids.int() - gt_ids = torch.tensor(range(0, text_cnt)).to( - self.device_id, non_blocking=True).int() - error_cnt = torch.nonzero(match_ids - gt_ids) - processed_cnt += text_cnt - - tp_cnt_per_batch.append(text_cnt - 1.0 * error_cnt.numel()) - logger.info('current acc: {:.3f}'.format( - sum(tp_cnt_per_batch) / processed_cnt)) + # dataset related + self.dataset_cfg = cfg.dataset + if hasattr(self.dataset_cfg, 'column_map'): + # cases where dataset key names are not "img" and "text" + img_key_name = getattr(self.dataset_cfg.column_map, 'img', 'img') + preprocessor[ConfigKeys.train].set_input_img_key(img_key_name) + preprocessor[ConfigKeys.val].set_input_img_key(img_key_name) + text_key_name = getattr(self.dataset_cfg.column_map, 'text', + 'text') + preprocessor[ConfigKeys.train].set_input_text_key(text_key_name) + preprocessor[ConfigKeys.val].set_input_text_key(text_key_name) + self.global_batch_size = cfg.train.dataloader.batch_size_per_gpu * world_size + + super().__init__( + model=model, + cfg_file=cfg_file, + arg_parse_fn=arg_parse_fn, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + preprocessor=preprocessor, + optimizers=optimizers, + seed=seed, + **kwargs, + ) + + def train_step(self, model, inputs): + model.train() + inputs['mode'] = ModeKeys.TRAIN + model_outputs = model.forward( + inputs + ) # {OutputKeys.IMG_EMBEDDING: Tensor(batch_size, dim), 
OutputKeys.TEXT_EMBEDDING: Tensor(batch_size, dim)} + loss = get_loss(model_outputs, self.loss_img, self.loss_txt, + self.loss_cfg) + train_outputs = {'loss': loss} + # add model output info to log + if 'log_vars' not in train_outputs: + default_keys_pattern = ['loss'] + match_keys = set([]) + for key_p in default_keys_pattern: + match_keys.update( + [key for key in train_outputs.keys() if key_p in key]) + log_vars = {} + for key in match_keys: + value = train_outputs.get(key, None) + if value is not None: + if dist.is_available() and dist.is_initialized(): + value = value.data.clone() + dist.all_reduce(value.div_(dist.get_world_size())) + log_vars.update({key: value.item()}) + unwrapped_model = getattr(model, 'module', model) + log_vars[ + 'logit_scale'] = unwrapped_model.clip_model.logit_scale.data.clone( + ).item() # noqa + log_vars['global_batch_size'] = int(self.global_batch_size) + self.log_buffer.update(log_vars) + else: + self.log_buffer.update(train_outputs['log_vars']) + self.train_outputs = train_outputs diff --git a/modelscope/trainers/multi_modal/clip/clip_trainer_utils.py b/modelscope/trainers/multi_modal/clip/clip_trainer_utils.py index 4e150fe7..fed255de 100644 --- a/modelscope/trainers/multi_modal/clip/clip_trainer_utils.py +++ b/modelscope/trainers/multi_modal/clip/clip_trainer_utils.py @@ -1,94 +1,125 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2022 The OFA-Sys Team. +# All rights reserved. +# This source code is licensed under the Apache 2.0 license +# found in the LICENSE file in the root directory. +import math import os -import random +from functools import partial +from inspect import unwrap -import json import torch -import torch.nn.functional as F -from PIL import Image -from torch.utils.data import Dataset -from torchvision import transforms - -from modelscope.utils.constant import ModeKeys - -train_transform = transforms.Compose([ - transforms.RandomResizedCrop( - 224, scale=(0.5, 1.0), interpolation=Image.BICUBIC), - transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], - p=0.8), - transforms.RandomGrayscale(p=0.2), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize((0.48145466, 0.4578275, 0.40821073), - (0.26862954, 0.26130258, 0.27577711)) -]) - -val_transform = transforms.Compose([ - transforms.Resize((224, 224), interpolation=Image.BICUBIC), - transforms.ToTensor(), - transforms.Normalize((0.48145466, 0.4578275, 0.40821073), - (0.26862954, 0.26130258, 0.27577711)) -]) - - -class ImageWithCaptionDataset(Dataset): - - def __init__(self, json_file, img_dir, phase): - self.annotations = json.load(open(json_file)) - self.img_dir = img_dir - if phase == ModeKeys.TRAIN: - self.transform = train_transform - elif phase == ModeKeys.EVAL: - self.transform = val_transform - - self.img_name2img_id = {} - for anno_dict in self.annotations: - img_name = anno_dict['image'] - if img_name not in self.img_name2img_id: - self.img_name2img_id[img_name] = len(self.img_name2img_id) - - def __len__(self): - return len(self.annotations) - - def __getitem__(self, index): - anno_dict = self.annotations[index] - - img_path = os.path.join(self.img_dir, anno_dict['image']) - img_pil = Image.open(img_path).convert('RGB') - img_th = self.transform(img_pil) - img_id = self.img_name2img_id[anno_dict['image']] - - text_str = random.choice(anno_dict['caption']) - - return img_th, text_str, img_id - - -def get_params_groups(ddp_model, weight_decay): - decay = [] - no_decay = [] - for name, param in 
ddp_model.named_parameters(): - if not param.requires_grad: - continue - if len(param.shape) == 1 or name.endswith('.bias'): - no_decay.append(param) - else: - decay.append(param) - params_groups = [{ - 'params': no_decay, - 'weight_decay': 0. - }, { - 'params': decay, - 'weight_decay': weight_decay - }] - return params_groups - - -def get_optimizer(ddp_model): - from torch.optim import AdamW - lr_init = 1e-5 - betas = [0.9, 0.999] - weight_decay = 0.02 - params_groups = get_params_groups(ddp_model, weight_decay=weight_decay) - return AdamW( - params_groups, lr=lr_init, betas=betas, weight_decay=weight_decay) +import torch.distributed as dist +from torch.optim.lr_scheduler import LambdaLR + +from modelscope.outputs import OutputKeys + + +def get_optimizer_params(model_name, cfg): + # get default params + # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) + # base model + if model_name in ['damo/multi-modal_clip-vit-base-patch16_zh']: + params = { + 'lr': 5.0e-4, + 'beta1': 0.9, + 'beta2': 0.98, + 'eps': 1.0e-6, + 'weight_decay': 0.0 + } + # large models + elif model_name in [ + 'damo/multi-modal_clip-vit-large-patch14_zh', + 'damo/multi-modal_clip-vit-large-patch14_336_zh' + ]: + params = { + 'lr': 4.0e-4, + 'beta1': 0.9, + 'beta2': 0.98, + 'eps': 1.0e-6, + 'weight_decay': 0.0 + } + else: + params = { + 'lr': 5.0e-4, + 'beta1': 0.9, + 'beta2': 0.999, + 'eps': 1.0e-8, + 'weight_decay': 0.0 + } + # override with config params + for key in ['lr', 'beta1', 'beta2', 'eps', 'weight_decay']: + if hasattr(cfg.train, 'optimizer_hparams'): + params[key] = getattr(cfg.train.optimizer_hparams, key, + params[key]) + return params + + +def get_loss(model_outputs, loss_img, loss_txt, loss_cfg): + image_features = model_outputs[OutputKeys.IMG_EMBEDDING] + text_features = model_outputs[OutputKeys.TEXT_EMBEDDING] + logit_scale = model_outputs['logit_scale'] + logit_scale = logit_scale.mean() + if loss_cfg.aggregate and int(os.environ.get('WORLD_SIZE', 1)) > 1: + world_size = dist.get_world_size() + rank = dist.get_rank() + + # We gather tensors from all gpus to get more negatives to contrast with. + gathered_image_features = [ + torch.zeros_like(image_features) for _ in range(world_size) + ] + gathered_text_features = [ + torch.zeros_like(text_features) for _ in range(world_size) + ] + dist.all_gather(gathered_image_features, image_features) + dist.all_gather(gathered_text_features, text_features) + + all_image_features = torch.cat([image_features] + + gathered_image_features[:rank] + + gathered_image_features[rank + 1:]) + all_text_features = torch.cat([text_features] + + gathered_text_features[:rank] + + gathered_text_features[rank + 1:]) + + # this is needed to send gradients back everywhere. 
+ logits_per_image = logit_scale * all_image_features @ all_text_features.t( + ) + logits_per_text = logits_per_image.t() + + else: + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logit_scale * text_features @ image_features.t() + + ground_truth = torch.arange(len(logits_per_image)).long() + ground_truth = ground_truth.cuda( + int(os.environ.get('LOCAL_RANK', 0)), non_blocking=True) + + total_loss = (loss_img(logits_per_image, ground_truth) + + loss_txt(logits_per_text, ground_truth)) / 2 + + return total_loss + + +def lr_lambda(num_warmup_steps, num_training_steps, num_cycles, current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps)) + return max( + 0.0, + 0.5 * # noqa + (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) # noqa + + +def get_schedule(optimizer, + scheduler, + num_cycles: float = 0.5, + last_epoch: int = -1): + num_warmup_steps = int(scheduler.warmup_proportion + * scheduler.num_train_steps) + num_training_steps = scheduler.num_train_steps + + return LambdaLR( + optimizer, + partial(lr_lambda, num_warmup_steps, num_training_steps, num_cycles), + last_epoch) diff --git a/tests/pipelines/test_multi_modal_embedding.py b/tests/pipelines/test_multi_modal_embedding.py index ee9cdb1f..7eddc690 100644 --- a/tests/pipelines/test_multi_modal_embedding.py +++ b/tests/pipelines/test_multi_modal_embedding.py @@ -24,7 +24,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): def test_run(self): pipeline_multi_modal_embedding = pipeline( Tasks.multi_modal_embedding, model=self.model_id) - text_embedding = pipeline_multi_modal_embedding( + text_embedding = pipeline_multi_modal_embedding.forward( self.test_input)[OutputKeys.TEXT_EMBEDDING] print('l1-norm: {}'.format( torch.norm(text_embedding, p=1, dim=-1).item())) @@ -36,7 +36,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): model = Model.from_pretrained(self.model_id) pipeline_multi_modal_embedding = pipeline( task=Tasks.multi_modal_embedding, model=model) - text_embedding = pipeline_multi_modal_embedding( + text_embedding = pipeline_multi_modal_embedding.forward( self.test_input)[OutputKeys.TEXT_EMBEDDING] print('l1-norm: {}'.format( torch.norm(text_embedding, p=1, dim=-1).item())) @@ -47,7 +47,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): def test_run_with_default_model(self): pipeline_multi_modal_embedding = pipeline( task=Tasks.multi_modal_embedding) - text_embedding = pipeline_multi_modal_embedding( + text_embedding = pipeline_multi_modal_embedding.forward( self.test_input)[OutputKeys.TEXT_EMBEDDING] print('l1-norm: {}'.format( torch.norm(text_embedding, p=1, dim=-1).item())) diff --git a/tests/trainers/test_clip_trainer.py b/tests/trainers/test_clip_trainer.py new file mode 100644 index 00000000..e460f1ac --- /dev/null +++ b/tests/trainers/test_clip_trainer.py @@ -0,0 +1,83 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os +import shutil +import unittest + +import json + +from modelscope.metainfo import Metrics, Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import test_level + + +class TestClipTrainer(unittest.TestCase): + + def setUp(self) -> None: + self.finetune_cfg = \ + {'framework': 'pytorch', + 'task': 'multi-modal-embedding', + 'pipeline': {'type': 'multi-modal-embedding'}, + 'pretrained_model': {'model_name': 'damo/multi-modal_clip-vit-base-patch16_zh'}, + 'dataset': {'column_map': {'img': 'image', 'text': 'query'}}, + 'train': {'work_dir': './workspace/ckpts/clip', + # 'launcher': 'pytorch', + 'max_epochs': 1, + 'use_fp16': True, + 'dataloader': {'batch_size_per_gpu': 8, + 'workers_per_gpu': 0, + 'shuffle': True, + 'drop_last': True}, + 'lr_scheduler': {'name': 'cosine', + 'warmup_proportion': 0.01}, + 'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False}, + 'optimizer': {'type': 'AdamW'}, + 'optimizer_hparams': {'lr': 5e-05, 'weight_decay': 0.01}, + 'optimizer_hook': {'type': 'TorchAMPOptimizerHook', + 'cumulative_iters': 1, + 'loss_keys': 'loss'}, + 'loss_cfg': {'aggregate': True}, + 'hooks': [{'type': 'BestCkptSaverHook', + 'metric_key': 'inbatch_t2i_recall_at_1', + 'interval': 100}, + {'type': 'TextLoggerHook', 'interval': 1}, + {'type': 'IterTimerHook'}, + {'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}, + {'type': 'ClipClampLogitScaleHook'}]}, + 'evaluation': {'dataloader': {'batch_size_per_gpu': 8, + 'workers_per_gpu': 0, + 'shuffle': True, + 'drop_last': True}, + 'metrics': [{'type': 'inbatch_recall'}]}, + 'preprocessor': []} + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_std(self): + WORKSPACE = './workspace/ckpts/clip' + os.makedirs(WORKSPACE, exist_ok=True) + config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION) + with open(config_file, 'w') as writer: + json.dump(self.finetune_cfg, writer) + + pretrained_model = 'damo/multi-modal_clip-vit-base-patch16_zh' + args = dict( + model=pretrained_model, + work_dir=WORKSPACE, + train_dataset=MsDataset.load( + 'muge', namespace='modelscope', split='train[:200]'), + eval_dataset=MsDataset.load( + 'muge', namespace='modelscope', split='validation[:100]'), + metrics=[Metrics.inbatch_recall], + cfg_file=config_file) + trainer = build_trainer( + name=Trainers.clip_multi_modal_embedding, default_args=args) + trainer.train() + + self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, + os.listdir(os.path.join(WORKSPACE, 'output'))) + shutil.rmtree(WORKSPACE) + + +if __name__ == '__main__': + unittest.main()
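
Reviewer note: below is a minimal end-to-end inference sketch for the refactored pipeline path, where preprocessing is now handled by the registered CLIPPreprocessor inside MultiModalEmbeddingPipeline. The model id matches the one used in the tests above; the image path, the query strings, and the variable names are illustrative placeholders, not part of this patch.

    from PIL import Image

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Build the pipeline; CLIPPreprocessor is created automatically from the model dir.
    pipe = pipeline(Tasks.multi_modal_embedding,
                    model='damo/multi-modal_clip-vit-base-patch16_zh')

    # 'img' and 'text' are the default input keys of CLIPPreprocessor; either may be omitted.
    inputs = {
        'img': Image.open('demo.jpg'),        # placeholder image path
        'text': ['一只猫', '一只狗'],           # Chinese queries for the Chinese CLIP model
    }

    # The updated tests call .forward() directly, since forward() already invokes preprocess().
    outputs = pipe.forward(inputs)

    img_emb = outputs[OutputKeys.IMG_EMBEDDING]    # shape (1, dim), l2-normalized
    text_emb = outputs[OutputKeys.TEXT_EMBEDDING]  # shape (2, dim), l2-normalized
    t2i_similarity = text_emb @ img_emb.t()        # cosine similarities, as in InbatchRecallMetric

For finetuning, tests/trainers/test_clip_trainer.py above is the reference entry point: build the trainer via Trainers.clip_multi_modal_embedding with a configuration that includes the ClipClampLogitScaleHook and the inbatch_recall metric. For multi-GPU runs, presumably the config's 'launcher': 'pytorch' option is enabled and the job is started with torchrun, since device placement and the all_gather in get_loss rely on the LOCAL_RANK and WORLD_SIZE environment variables.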