Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10572842 (master)
@@ -389,6 +389,7 @@ class Preprocessors(object): | |||
# multi-modal preprocessor | |||
ofa_tasks_preprocessor = 'ofa-tasks-preprocessor' | |||
clip_preprocessor = 'clip-preprocessor' | |||
mplug_tasks_preprocessor = 'mplug-tasks-preprocessor' | |||
# science preprocessor | |||
@@ -428,6 +429,8 @@ class Metrics(object): | |||
image_inpainting_metric = 'image-inpainting-metric' | |||
# metric for ocr | |||
NED = 'ned' | |||
# metric for cross-modal retrieval | |||
inbatch_recall = 'inbatch_recall' | |||
# metric for referring-video-object-segmentation task | |||
referring_video_object_segmentation_metric = 'referring-video-object-segmentation-metric' | |||
@@ -474,6 +477,9 @@ class Hooks(object): | |||
# Compression | |||
SparsityHook = 'SparsityHook' | |||
# CLIP logit_scale clamp | |||
ClipClampLogitScaleHook = 'ClipClampLogitScaleHook' | |||
class LR_Schedulers(object): | |||
"""learning rate scheduler is defined here | |||
@@ -24,6 +24,7 @@ class MetricKeys(object): | |||
ROUGE_1 = 'rouge-1' | |||
ROUGE_L = 'rouge-l' | |||
NED = 'ned' # ocr metric | |||
BatchAcc = 'inbatch_t2i_recall_at_1' | |||
task_default_metrics = { | |||
@@ -0,0 +1,55 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Dict | |||
import numpy as np | |||
import torch | |||
from modelscope.metainfo import Metrics | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.utils.registry import default_group | |||
from .base import Metric | |||
from .builder import METRICS, MetricKeys | |||
@METRICS.register_module( | |||
group_key=default_group, module_name=Metrics.inbatch_recall) | |||
class InbatchRecallMetric(Metric): | |||
"""The metric computation class for in-batch retrieval classes. | |||
This metric class calculates in-batch image recall@1 for each input batch. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
super().__init__(*args, **kwargs) | |||
self.inbatch_t2i_hitcnts = [] | |||
self.batch_sizes = [] | |||
def add(self, outputs: Dict, inputs: Dict): | |||
image_features = outputs[OutputKeys.IMG_EMBEDDING] | |||
text_features = outputs[OutputKeys.TEXT_EMBEDDING] | |||
assert type(image_features) == torch.Tensor and type( | |||
text_features) == torch.Tensor | |||
with torch.no_grad(): | |||
logits_per_image = image_features @ text_features.t() | |||
logits_per_text = logits_per_image.t() | |||
batch_size = logits_per_image.shape[0] | |||
ground_truth = torch.arange(batch_size).long() | |||
ground_truth = ground_truth.to(image_features.device) | |||
inbatch_t2i_hitcnt = (logits_per_text.argmax(-1) == ground_truth | |||
).sum().float().item() | |||
self.inbatch_t2i_hitcnts.append(inbatch_t2i_hitcnt) | |||
self.batch_sizes.append(batch_size) | |||
def evaluate(self): | |||
assert len(self.inbatch_t2i_hitcnts) == len( | |||
self.batch_sizes) and len(self.batch_sizes) > 0 | |||
return { | |||
MetricKeys.BatchAcc: | |||
sum(self.inbatch_t2i_hitcnts) / sum(self.batch_sizes) | |||
} |
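Reviewer note: a minimal, self-contained sketch of what InbatchRecallMetric computes (dummy data and a standalone function rather than the registered class): text-to-image recall@1 within a batch, where the ground-truth image for the i-th text is the i-th image.

import torch
import torch.nn.functional as F

def inbatch_t2i_recall_at_1(image_features: torch.Tensor,
                            text_features: torch.Tensor) -> float:
    # rows index texts, columns index images; text i is paired with image i
    logits_per_text = text_features @ image_features.t()
    ground_truth = torch.arange(logits_per_text.shape[0])
    hits = (logits_per_text.argmax(dim=-1) == ground_truth).sum().float()
    return (hits / logits_per_text.shape[0]).item()

img = F.normalize(torch.randn(8, 512), dim=-1)   # stand-in for IMG_EMBEDDING
txt = F.normalize(torch.randn(8, 512), dim=-1)   # stand-in for TEXT_EMBEDDING
print(inbatch_t2i_recall_at_1(img, txt))         # ~0.125 for random features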
@@ -15,15 +15,13 @@ | |||
import os | |||
from collections import OrderedDict | |||
from typing import Any, Dict, Iterable, List, Tuple, Union | |||
from typing import Any, Dict, Tuple, Union | |||
import json | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from PIL import Image | |||
from torchvision.transforms import Compose, Normalize, Resize, ToTensor | |||
from modelscope.metainfo import Models | |||
from modelscope.models import TorchModel | |||
@@ -506,21 +504,6 @@ def convert_weights(model: nn.Module): | |||
model.apply(_convert_weights_to_fp16) | |||
def _convert_to_rgb(image): | |||
return image.convert('RGB') | |||
def image_transform(image_size=224): | |||
transform = Compose([ | |||
_convert_to_rgb, | |||
Resize((image_size, image_size)), | |||
ToTensor(), | |||
Normalize((0.48145466, 0.4578275, 0.40821073), | |||
(0.26862954, 0.26130258, 0.27577711)), | |||
]) | |||
return transform | |||
@MODELS.register_module(Tasks.multi_modal_embedding, module_name=Models.clip) | |||
class CLIPForMultiModalEmbedding(TorchModel): | |||
@@ -540,72 +523,40 @@ class CLIPForMultiModalEmbedding(TorchModel): | |||
with open(vision_model_config_file, | |||
'r') as fv, open(text_model_config_file, 'r') as ft: | |||
model_info = json.load(fv) | |||
self.model_info = json.load(fv) | |||
for k, v in json.load(ft).items(): | |||
model_info[k] = v | |||
# image preprocess | |||
self.img_preprocess = image_transform(model_info['image_resolution']) | |||
self.model_info[k] = v | |||
# text tokenizer | |||
vocab_file = f'{model_dir}/{ModelFile.VOCAB_FILE}' | |||
self.tokenizer = FullTokenizer(vocab_file=vocab_file) | |||
# initialize the model | |||
self.clip_model = CLIP(**model_info, tokenizer=self.tokenizer) | |||
self.clip_model = CLIP(**self.model_info, tokenizer=self.tokenizer) | |||
convert_weights(self.clip_model) | |||
# restore the pretrained weight | |||
checkpoint = torch.load( | |||
f'{model_dir}/{ModelFile.TORCH_MODEL_BIN_FILE}', 'cpu') | |||
sd = checkpoint['state_dict'] | |||
sd = checkpoint[ | |||
'state_dict'] if 'state_dict' in checkpoint else checkpoint | |||
if next(iter(sd.items()))[0].startswith('module'): | |||
sd = {k[len('module.'):]: v for k, v in sd.items()} | |||
# support the finetuned model | |||
if next(iter(sd.items()))[0].startswith('clip_model'): | |||
sd = {k[len('clip_model.'):]: v for k, v in sd.items()} | |||
self.clip_model.load_state_dict(sd) | |||
self.clip_model.eval() | |||
# place the model | |||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' | |||
if self.device == 'cuda': | |||
self.device = 'cuda:{}'.format(int(os.environ.get( | |||
'LOCAL_RANK', 0))) if torch.cuda.is_available() else 'cpu' | |||
if torch.cuda.is_available(): | |||
self.clip_model.to(self.device) | |||
logger.info('Use GPU for inference') | |||
logger.info('Use GPU {} for finetuning & inference'.format( | |||
int(os.environ.get('LOCAL_RANK', 0)))) | |||
else: | |||
self.clip_model.float() | |||
logger.info('Use CPU for inference') | |||
def tokenize(self, | |||
texts: Union[str, List[str]], | |||
context_length: int = 52) -> torch.LongTensor: | |||
""" | |||
Returns the tokenized representation of given input string(s) | |||
Parameters | |||
---------- | |||
texts : Union[str, List[str]] | |||
An input string or a list of input strings to tokenize | |||
context_length : int | |||
The context length to use; all baseline models use 52 as the context length | |||
Returns | |||
------- | |||
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] | |||
""" | |||
if isinstance(texts, str): | |||
texts = [texts] | |||
all_tokens = [] | |||
for text in texts: | |||
all_tokens.append( | |||
[self.tokenizer.vocab['[CLS]']] | |||
+ self.tokenizer.convert_tokens_to_ids( | |||
self.tokenizer.tokenize(text))[:context_length - 2] | |||
+ [self.tokenizer.vocab['[SEP]']]) | |||
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) | |||
for i, tokens in enumerate(all_tokens): | |||
assert len(tokens) <= context_length | |||
result[i, :len(tokens)] = torch.tensor(tokens) | |||
return result | |||
logger.info('Use CPU for finetuning & inference') | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
from modelscope.outputs import OutputKeys | |||
@@ -613,75 +564,36 @@ class CLIPForMultiModalEmbedding(TorchModel): | |||
OutputKeys.IMG_EMBEDDING: None, | |||
OutputKeys.TEXT_EMBEDDING: None | |||
} | |||
if 'img' in input and input['img'] is not None: | |||
image_input = input['img'] | |||
# single image input | |||
if isinstance(image_input, Image.Image): | |||
image_tensor = self.img_preprocess(image_input).unsqueeze(0) | |||
# multi images input | |||
elif isinstance(image_input, list): | |||
if all([isinstance(elem, Image.Image) | |||
for elem in image_input]): | |||
image_tensor = torch.stack( | |||
[self.img_preprocess(elem) for elem in image_input], | |||
dim=0) | |||
else: | |||
unsupported_elem_type = [ | |||
type(elem) for elem in image_input | |||
if not isinstance(elem, Image.Image) | |||
][0] | |||
raise TypeError( | |||
f'img should be PIL.Image or List[PIL.Image], \ | |||
but got a List containing one {unsupported_elem_type}' | |||
) | |||
# others | |||
else: | |||
raise TypeError( | |||
f'img should be PIL.Image or List[PIL.Image], but got {type(image_input)}' | |||
) | |||
image_tensor = image_tensor.to(self.device) | |||
with torch.no_grad(): | |||
mode = input.get('mode', ModeKeys.INFERENCE) | |||
# encode the image | |||
if 'img' in input and isinstance(input['img'], torch.Tensor): | |||
image_tensor = input['img'].to(self.device) | |||
if image_tensor.dim() == 5 and image_tensor.shape[1] == 1: | |||
image_tensor = image_tensor.squeeze(1) | |||
with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN): | |||
image_features = self.clip_model.encode_image(image_tensor) | |||
image_features /= image_features.norm( | |||
dim=-1, keepdim=True) # l2-normalize | |||
output[OutputKeys.IMG_EMBEDDING] = image_features | |||
if 'text' in input and input['text'] is not None: | |||
text_input = input['text'] | |||
# single text input | |||
if isinstance(text_input, str): | |||
text_tensor = self.tokenize(text_input) | |||
# multi texts input | |||
elif isinstance(text_input, list): | |||
if all([isinstance(elem, str) for elem in text_input]): | |||
text_tensor = self.tokenize(text_input) | |||
else: | |||
unsupported_elem_type = [ | |||
type(elem) for elem in text_input | |||
if not isinstance(elem, str) | |||
][0] | |||
raise TypeError( | |||
f'text should be str or List[str], but got a List containing one {unsupported_elem_type}' | |||
) | |||
# others | |||
else: | |||
raise TypeError( | |||
f'text should be str or List[str], but got {type(text_input)}' | |||
) | |||
text_tensor = text_tensor.to(self.device) | |||
with torch.no_grad(): | |||
if 'text' in input and isinstance(input['text'], torch.Tensor): | |||
text_tensor = input['text'].to(self.device) | |||
if text_tensor.dim() == 3 and text_tensor.shape[1] == 1: | |||
text_tensor = text_tensor.squeeze(1) | |||
with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN): | |||
text_features = self.clip_model.encode_text(text_tensor) | |||
text_features /= text_features.norm( | |||
dim=-1, keepdim=True) # l2-normalize | |||
output[OutputKeys.TEXT_EMBEDDING] = text_features | |||
if mode == ModeKeys.TRAIN: | |||
output['logit_scale'] = (self.clip_model.logit_scale | |||
* 1.0).exp().mean() | |||
return output | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
@@ -1,10 +1,12 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict | |||
from typing import Any, Dict, Optional, Union | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.multi_modal.clip.model import CLIPForMultiModalEmbedding | |||
from modelscope.pipelines.base import Input, Model, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors.multi_modal import CLIPPreprocessor, Preprocessor | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
@@ -17,7 +19,10 @@ logger = get_logger() | |||
Tasks.multi_modal_embedding, module_name=Pipelines.multi_modal_embedding) | |||
class MultiModalEmbeddingPipeline(Pipeline): | |||
def __init__(self, model: str, device: str = 'gpu'): | |||
def __init__(self, | |||
model: Union[Model, str], | |||
preprocessor: Optional[Preprocessor] = None, | |||
**kwargs): | |||
""" | |||
use `model` and `preprocessor` to create a multi-modal embedding pipeline for prediction
Args: | |||
@@ -29,14 +34,17 @@ class MultiModalEmbeddingPipeline(Pipeline): | |||
pipe_model = model | |||
else: | |||
raise NotImplementedError('model must be a single str') | |||
pipe_model.eval() | |||
if preprocessor is None: | |||
if isinstance(pipe_model, CLIPForMultiModalEmbedding): | |||
preprocessor = CLIPPreprocessor(pipe_model.model_dir) | |||
else: | |||
raise NotImplementedError | |||
super().__init__(model=pipe_model) | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
return input | |||
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
return self.model(input) | |||
return self.model(self.preprocess(input)) | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
return inputs |
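Reviewer note: with the pipeline now owning a CLIPPreprocessor, end-to-end usage looks roughly like the sketch below. The model id is the one used in the unit tests of this PR; the image path and query text are placeholders, and .forward() mirrors the updated tests further down.

from PIL import Image
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(Tasks.multi_modal_embedding,
                model='damo/multi-modal_clip-vit-base-patch16_zh')
inputs = {'img': Image.open('demo.jpg'), 'text': '一张风景照片'}
outputs = pipe.forward(inputs)
print(outputs[OutputKeys.IMG_EMBEDDING].shape,   # (1, embed_dim)
      outputs[OutputKeys.TEXT_EMBEDDING].shape)  # (1, embed_dim)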
@@ -3,8 +3,11 @@ import os.path as osp | |||
from io import BytesIO | |||
from typing import Any, Dict, List, Tuple, Union | |||
import json | |||
import torch | |||
from PIL import Image | |||
from timm.data import create_transform | |||
from torchvision.transforms import Compose, Normalize, Resize, ToTensor | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.metainfo import Preprocessors | |||
@@ -107,6 +110,180 @@ class OfaPreprocessor(Preprocessor): | |||
eos_idx=self.tokenizer.eos_token_id) | |||
def _convert_to_rgb(image): | |||
return image.convert('RGB') | |||
@PREPROCESSORS.register_module( | |||
Fields.multi_modal, module_name=Preprocessors.clip_preprocessor) | |||
class CLIPPreprocessor(Preprocessor): | |||
def __init__(self, | |||
model_dir: str, | |||
mode=ModeKeys.INFERENCE, | |||
*args, | |||
**kwargs): | |||
"""preprocess the data | |||
Args: | |||
model_dir (str): model path | |||
mode: preprocessor mode (model mode) | |||
""" | |||
super().__init__(*args, **kwargs) | |||
model_dir = model_dir if osp.exists(model_dir) else snapshot_download( | |||
model_dir) | |||
self.mode = mode | |||
# text tokenizer | |||
from modelscope.models.multi_modal.clip.bert_tokenizer import FullTokenizer | |||
if 'tokenizer' in kwargs and isinstance(kwargs['tokenizer'], | |||
FullTokenizer): | |||
self.tokenizer = kwargs['tokenizer'] | |||
else: | |||
vocab_file = f'{model_dir}/{ModelFile.VOCAB_FILE}' | |||
self.tokenizer = FullTokenizer(vocab_file=vocab_file) | |||
# image preprocessor | |||
if 'resolution' in kwargs and isinstance(kwargs['resolution'], int): | |||
self.image_resolution = kwargs['resolution'] | |||
else: | |||
self.image_resolution = json.load( | |||
open('{}/vision_model_config.json'.format( | |||
model_dir)))['image_resolution'] | |||
self.img_preprocess = self._build_image_transform() | |||
# key mapping | |||
# specify the input keys, since training and inference data may use different key names | |||
self.input_keys = {'img': 'img', 'text': 'text'} | |||
def _build_image_transform(self): | |||
if self.mode == ModeKeys.TRAIN: | |||
transform = create_transform( | |||
input_size=self.image_resolution, | |||
scale=(0.9, 1.0), | |||
is_training=True, | |||
color_jitter=None, | |||
auto_augment='original', | |||
interpolation='bicubic', | |||
mean=(0.48145466, 0.4578275, 0.40821073), | |||
std=(0.26862954, 0.26130258, 0.27577711), | |||
) | |||
transform = Compose(transform.transforms[:-3] + [_convert_to_rgb] | |||
+ transform.transforms[-3:]) | |||
else: | |||
transform = Compose([ | |||
Resize((self.image_resolution, self.image_resolution), | |||
interpolation=Image.BICUBIC), | |||
_convert_to_rgb, | |||
ToTensor(), | |||
Normalize((0.48145466, 0.4578275, 0.40821073), | |||
(0.26862954, 0.26130258, 0.27577711)), | |||
]) | |||
return transform | |||
def tokenize(self, | |||
texts: Union[str, List[str]], | |||
context_length: int = 52) -> torch.LongTensor: | |||
""" | |||
Returns the tokenized representation of given input string(s) | |||
Parameters | |||
---------- | |||
texts : Union[str, List[str]] | |||
An input string or a list of input strings to tokenize | |||
context_length : int | |||
The context length to use; all baseline models use 52 as the context length | |||
Returns | |||
------- | |||
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] | |||
""" | |||
if isinstance(texts, str): | |||
texts = [texts] | |||
all_tokens = [] | |||
for text in texts: | |||
all_tokens.append( | |||
[self.tokenizer.vocab['[CLS]']] | |||
+ self.tokenizer.convert_tokens_to_ids( | |||
self.tokenizer.tokenize(text))[:context_length - 2] | |||
+ [self.tokenizer.vocab['[SEP]']]) | |||
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) | |||
for i, tokens in enumerate(all_tokens): | |||
assert len(tokens) <= context_length | |||
result[i, :len(tokens)] = torch.tensor(tokens) | |||
return result | |||
def set_input_img_key(self, new_key: str): | |||
self.input_keys['img'] = new_key | |||
def set_input_text_key(self, new_key: str): | |||
self.input_keys['text'] = new_key | |||
def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args, | |||
**kwargs) -> Dict[str, Any]: | |||
output = {} | |||
# preprocess the image input | |||
input_img_key = self.input_keys['img'] | |||
if input_img_key in input and input[input_img_key] is not None: | |||
image_input = input[input_img_key] | |||
# single image input | |||
if isinstance(image_input, Image.Image): | |||
image_tensor = self.img_preprocess(image_input).unsqueeze(0) | |||
# multi images input | |||
elif isinstance(image_input, list): | |||
if all([isinstance(elem, Image.Image) | |||
for elem in image_input]): | |||
image_tensor = torch.stack( | |||
[self.img_preprocess(elem) | |||
for elem in image_input], # noqa | |||
dim=0) # noqa | |||
else: | |||
unsupported_elem_type = [ | |||
type(elem) for elem in image_input | |||
if not isinstance(elem, Image.Image) | |||
][0] | |||
raise TypeError( | |||
f'img should be PIL.Image or List[PIL.Image], \ | |||
but got a List containing one {unsupported_elem_type}' | |||
) | |||
# others | |||
else: | |||
raise TypeError( | |||
f'img should be PIL.Image or List[PIL.Image], but got {type(image_input)}' | |||
) | |||
output['img'] = image_tensor | |||
# preprocess the text input | |||
input_text_key = self.input_keys['text'] | |||
if input_text_key in input and input[input_text_key] is not None: | |||
text_input = input[input_text_key] | |||
# single text input | |||
if isinstance(text_input, str): | |||
text_tensor = self.tokenize(text_input) | |||
# multi texts input | |||
elif isinstance(text_input, list): | |||
if all([isinstance(elem, str) for elem in text_input]): | |||
text_tensor = self.tokenize(text_input) | |||
else: | |||
unsupported_elem_type = [ | |||
type(elem) for elem in text_input | |||
if not isinstance(elem, str) | |||
][0] | |||
raise TypeError( | |||
f'text should be str or List[str], but got a List containing one {unsupported_elem_type}' | |||
) | |||
# others | |||
else: | |||
raise TypeError( | |||
f'text should be str or List[str], but got {type(text_input)}' | |||
) | |||
output['text'] = text_tensor | |||
return output | |||
@PREPROCESSORS.register_module( | |||
Fields.multi_modal, module_name=Preprocessors.mplug_tasks_preprocessor) | |||
class MPlugPreprocessor(Preprocessor): | |||
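Reviewer note: the new CLIPPreprocessor can also be exercised on its own; a small sketch is below (model id taken from the tests in this PR, image path and texts are placeholders). It returns tensors keyed by 'img' and 'text', which is what the model's forward now expects.

from PIL import Image
from modelscope.preprocessors.multi_modal import CLIPPreprocessor

prep = CLIPPreprocessor('damo/multi-modal_clip-vit-base-patch16_zh')
batch = prep({'img': Image.open('demo.jpg'), 'text': ['一只猫', '一只狗']})
print(batch['img'].shape)   # (1, 3, image_resolution, image_resolution)
print(batch['text'].shape)  # (2, 52): [CLS] + tokens + [SEP], zero-padded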
@@ -0,0 +1,18 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import torch | |||
from modelscope.metainfo import Hooks | |||
from modelscope.trainers.multi_modal.clip.clip_trainer import CLIPTrainer | |||
from .builder import HOOKS | |||
from .hook import Hook | |||
@HOOKS.register_module(module_name=Hooks.ClipClampLogitScaleHook) | |||
class ClipClampLogitScaleHook(Hook): | |||
"""ClipClampLogitScaleHook hook which performs clamp on CLIP logit scale parameter after update""" | |||
def after_train_iter(self, trainer: CLIPTrainer): | |||
"""Called after every training iter to evaluate the results.""" | |||
unwrapped_model = getattr(trainer.model, 'module', trainer.model) | |||
logit_scale = unwrapped_model.clip_model.logit_scale | |||
logit_scale.data = torch.clamp(logit_scale.data, 0, 4.6052) |
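Reviewer note: the clamp bound follows the original CLIP recipe of capping the learnable temperature so that exp(logit_scale) stays at or below 100; 4.6052 is just log(100) rounded. A quick check (the logit_scale value here is hypothetical):

import math
import torch

logit_scale = torch.tensor(5.3)                 # hypothetical post-step value
clamped = torch.clamp(logit_scale, 0.0, 4.6052)
print(round(math.log(100.0), 4))                # 4.6052 -> the upper bound
print(clamped.exp())                            # ~100, the temperature cap from the CLIP paper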
@@ -1,169 +1,206 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import math | |||
import os | |||
from typing import Dict, Optional | |||
from typing import Callable, Dict, Optional, Tuple, Union | |||
import torch | |||
import torch.distributed as dist | |||
from torch.utils.data import DataLoader | |||
from torch.utils.data.distributed import DistributedSampler | |||
from torch import distributed as dist | |||
from torch import nn | |||
from torch.utils.data import Dataset | |||
from modelscope.metainfo import Trainers | |||
from modelscope.models.base import Model | |||
from modelscope.trainers.base import BaseTrainer | |||
from modelscope.models.base import Model, TorchModel | |||
from modelscope.models.multi_modal.clip.model import convert_models_to_fp32 | |||
from modelscope.msdatasets.ms_dataset import MsDataset | |||
from modelscope.preprocessors.base import Preprocessor | |||
from modelscope.preprocessors.multi_modal import CLIPPreprocessor | |||
from modelscope.trainers import EpochBasedTrainer | |||
from modelscope.trainers.builder import TRAINERS | |||
from modelscope.trainers.optimizer.builder import build_optimizer | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModeKeys | |||
from modelscope.utils.logger import get_logger | |||
from .clip_trainer_utils import ImageWithCaptionDataset, get_optimizer | |||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys, | |||
ModeKeys) | |||
from .clip_trainer_utils import get_loss, get_optimizer_params, get_schedule | |||
logger = get_logger() | |||
def exclude(n): | |||
return 'bn' in n or 'ln' in n or 'bias' in n or 'logit_scale' in n | |||
def include(n): | |||
return not exclude(n) | |||
@TRAINERS.register_module(module_name=Trainers.clip_multi_modal_embedding) | |||
class CLIPTrainer(BaseTrainer): | |||
def __init__(self, cfg_file: str, model: str, device_id: int, *args, | |||
**kwargs): | |||
super().__init__(cfg_file) | |||
self.cfg = Config.from_file(cfg_file) | |||
self.model = Model.from_pretrained(model) | |||
self.device_id = device_id | |||
self.total_epoch = self.cfg.train.epoch | |||
self.train_batch_size = self.cfg.train.batch_size | |||
self.val_batch_size = self.cfg.evaluation.batch_size | |||
self.ckpt_dir = self.cfg.train.ckpt_dir | |||
self.train_dataset = ImageWithCaptionDataset( | |||
json_file='{}/{}'.format(self.cfg.dataset.root_dir, | |||
self.cfg.dataset.train_set), | |||
img_dir=self.cfg.dataset.root_dir, | |||
phase=ModeKeys.TRAIN) | |||
self.val_dataset = ImageWithCaptionDataset( | |||
json_file='{}/{}'.format(self.cfg.dataset.root_dir, | |||
self.cfg.dataset.val_set), | |||
img_dir=self.cfg.dataset.root_dir, | |||
phase=ModeKeys.EVAL) | |||
def train(self, *args, **kwargs): | |||
assert dist.is_initialized() | |||
self.model.clip_model.train() | |||
self.model.clip_model.to(self.device_id) | |||
ddp_model = torch.nn.parallel.DistributedDataParallel( | |||
self.model.clip_model, device_ids=[ | |||
self.device_id, | |||
]) | |||
optimizer = get_optimizer(ddp_model) | |||
for epoch in range(self.total_epoch): | |||
train_sampler = DistributedSampler( | |||
dataset=self.train_dataset, shuffle=True) | |||
train_sampler.set_epoch(epoch) | |||
train_params = { | |||
'pin_memory': True, | |||
'collate_fn': None, | |||
'batch_size': self.train_batch_size, | |||
'shuffle': False, | |||
'drop_last': True, | |||
'sampler': train_sampler, | |||
'num_workers': 8 | |||
class CLIPTrainer(EpochBasedTrainer): | |||
def __init__( | |||
self, | |||
model: Optional[Union[TorchModel, nn.Module, str]] = None, | |||
cfg_file: Optional[str] = None, | |||
arg_parse_fn: Optional[Callable] = None, | |||
data_collator: Optional[Union[Callable, Dict[str, | |||
Callable]]] = None, | |||
train_dataset: Optional[Union[MsDataset, Dataset]] = None, | |||
eval_dataset: Optional[Union[MsDataset, Dataset]] = None, | |||
preprocessor: Optional[Union[Preprocessor, | |||
Dict[str, Preprocessor]]] = None, | |||
optimizers: Tuple[torch.optim.Optimizer, | |||
torch.optim.lr_scheduler._LRScheduler] = (None, | |||
None), | |||
model_revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
seed: int = 42, | |||
**kwargs): | |||
model = Model.from_pretrained(model, revision=model_revision) | |||
# for training & eval, we convert the model from FP16 back to FP32 | |||
# to be compatible with modelscope amp training | |||
convert_models_to_fp32(model) | |||
cfg = Config.from_file(cfg_file) | |||
if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0: | |||
work_dir = cfg.train.work_dir | |||
else: | |||
work_dir = kwargs['work_dir'] | |||
# fetch the model name of CLIP model (base, large or large-336) | |||
model_name = cfg.pretrained_model.model_name | |||
# world size | |||
world_size = int(os.environ.get('WORLD_SIZE', 1)) | |||
# train step, optimizer and lr_scheduler | |||
epoch_steps = math.ceil( | |||
len(train_dataset) / # noqa | |||
(cfg.train.dataloader.batch_size_per_gpu * world_size)) # noqa | |||
cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs | |||
if optimizers[0] is None: | |||
named_parameters = list(model.named_parameters()) | |||
gain_or_bias_params = [ | |||
p for n, p in named_parameters | |||
if exclude(n) and p.requires_grad | |||
] | |||
rest_params = [ | |||
p for n, p in named_parameters | |||
if include(n) and p.requires_grad | |||
] | |||
optimizer_hparams = get_optimizer_params( | |||
model_name, cfg) # lr, wd, beta1, beta2, eps | |||
optimizer_args = { | |||
'params': [ | |||
{ | |||
'params': gain_or_bias_params, | |||
'weight_decay': 0. | |||
}, | |||
{ | |||
'params': rest_params, | |||
'weight_decay': optimizer_hparams['weight_decay'] | |||
}, | |||
], | |||
'lr': | |||
optimizer_hparams['lr'], | |||
'betas': | |||
(optimizer_hparams['beta1'], optimizer_hparams['beta2']), | |||
'eps': | |||
optimizer_hparams['eps'], | |||
} | |||
optimizer = build_optimizer( | |||
model, cfg=cfg.train.optimizer, default_args=optimizer_args) | |||
else: | |||
optimizer = optimizers[0] | |||
if optimizers[1] is None: | |||
lr_scheduler = get_schedule(optimizer, cfg.train.lr_scheduler) | |||
else: | |||
lr_scheduler = optimizers[1] | |||
optimizers = (optimizer, lr_scheduler) | |||
# loss module | |||
loss_img = nn.CrossEntropyLoss() | |||
loss_txt = nn.CrossEntropyLoss() | |||
self.loss_img = loss_img.cuda(int(os.environ.get('LOCAL_RANK', 0))) | |||
self.loss_txt = loss_txt.cuda(int(os.environ.get('LOCAL_RANK', 0))) | |||
self.loss_cfg = cfg.train.loss_cfg | |||
# launcher and use_fp16 | |||
if 'launcher' not in kwargs and cfg.train.get('launcher', None): | |||
kwargs['launcher'] = cfg.train.launcher | |||
if 'use_fp16' not in kwargs and cfg.train.get('use_fp16', False): | |||
kwargs['use_fp16'] = cfg.train.use_fp16 | |||
# preprocessor | |||
if preprocessor is None: | |||
preprocessor = { | |||
ConfigKeys.train: | |||
CLIPPreprocessor( | |||
model_dir=work_dir, | |||
mode=ModeKeys.TRAIN, | |||
tokenizer=model.tokenizer, | |||
resolution=model.model_info['image_resolution']), | |||
ConfigKeys.val: | |||
CLIPPreprocessor( | |||
model_dir=work_dir, | |||
mode=ModeKeys.EVAL, | |||
tokenizer=model.tokenizer, | |||
resolution=model.model_info['image_resolution']), | |||
} | |||
train_loader = DataLoader(self.train_dataset, **train_params) | |||
for batch_idx, (img_tensor, text_str_list, | |||
img_id_list) in enumerate(train_loader): | |||
text_info_list = [ | |||
self.model.tokenize_text(tmp) for tmp in text_str_list | |||
] | |||
text_ids_tensor = torch.cat([tmp[0] for tmp in text_info_list], | |||
dim=0) | |||
text_masks_tensor = torch.cat( | |||
[tmp[1] for tmp in text_info_list], dim=0) | |||
img_tensor = img_tensor.to(self.device_id, non_blocking=True) | |||
img_id_list = img_id_list.to(self.device_id, non_blocking=True) | |||
text_ids_tensor = text_ids_tensor.to( | |||
self.device_id, non_blocking=True) | |||
text_masks_tensor = text_masks_tensor.to( | |||
self.device_id, non_blocking=True) | |||
loss = ddp_model((img_tensor, text_ids_tensor, | |||
text_masks_tensor, img_id_list), | |||
ModeKeys.TRAIN) | |||
optimizer.zero_grad() | |||
loss.backward() | |||
optimizer.step() | |||
if batch_idx % 10 == 0: | |||
logger.info( | |||
'epoch: {}, train batch {}/{}, loss={:.5f}, logit_scale={:.5f}' | |||
.format(epoch, batch_idx, len(train_loader), | |||
loss.item(), | |||
ddp_model.module.logit_scale.exp().item())) | |||
if dist.get_rank() == 0: | |||
os.makedirs(self.ckpt_dir, exist_ok=True) | |||
torch.save(ddp_model.module.state_dict(), | |||
'{}/epoch{}.pth'.format(self.ckpt_dir, epoch)) | |||
def evaluate(self, | |||
checkpoint_path: Optional[str] = None, | |||
*args, | |||
**kwargs) -> Dict[str, float]: | |||
if checkpoint_path is not None: | |||
checkpoint_params = torch.load(checkpoint_path, 'cpu') | |||
self.model.clip_model.load_state_dict(checkpoint_params) | |||
self.model.clip_model.eval() | |||
self.model.clip_model.to(self.device_id) | |||
val_params = { | |||
'collate_fn': None, | |||
'batch_size': self.val_batch_size, | |||
'shuffle': False, | |||
'drop_last': False, | |||
'num_workers': 8 | |||
} | |||
val_loader = DataLoader(self.val_dataset, **val_params) | |||
tp_cnt_per_batch = [] | |||
processed_cnt = 0 | |||
with torch.no_grad(): | |||
for batch_idx, (img_tensor, text_str_list, | |||
img_id_list) in enumerate(val_loader): | |||
text_info_list = [ | |||
self.model.tokenize_text(tmp) for tmp in text_str_list | |||
] | |||
text_ids_tensor = torch.cat([tmp[0] for tmp in text_info_list], | |||
dim=0) | |||
text_masks_tensor = torch.cat( | |||
[tmp[1] for tmp in text_info_list], dim=0) | |||
img_tensor = img_tensor.to(self.device_id, non_blocking=True) | |||
img_id_list = img_id_list.to(self.device_id, non_blocking=True) | |||
text_ids_tensor = text_ids_tensor.to( | |||
self.device_id, non_blocking=True) | |||
text_masks_tensor = text_masks_tensor.to( | |||
self.device_id, non_blocking=True) | |||
img_feat = self.model.clip_model(img_tensor, input_type='img') | |||
text_feat = self.model.clip_model( | |||
(text_ids_tensor, text_masks_tensor), input_type='text') | |||
sim_mat = text_feat @ img_feat.t() | |||
text_cnt, img_cnt = sim_mat.shape | |||
top1_scores, match_ids = torch.max(sim_mat, dim=1) | |||
match_ids = match_ids.int() | |||
gt_ids = torch.tensor(range(0, text_cnt)).to( | |||
self.device_id, non_blocking=True).int() | |||
error_cnt = torch.nonzero(match_ids - gt_ids) | |||
processed_cnt += text_cnt | |||
tp_cnt_per_batch.append(text_cnt - 1.0 * error_cnt.numel()) | |||
logger.info('current acc: {:.3f}'.format( | |||
sum(tp_cnt_per_batch) / processed_cnt)) | |||
# dataset related | |||
self.dataset_cfg = cfg.dataset | |||
if hasattr(self.dataset_cfg, 'column_map'): | |||
# cases where dataset key names are not "img" and "text" | |||
img_key_name = getattr(self.dataset_cfg.column_map, 'img', 'img') | |||
preprocessor[ConfigKeys.train].set_input_img_key(img_key_name) | |||
preprocessor[ConfigKeys.val].set_input_img_key(img_key_name) | |||
text_key_name = getattr(self.dataset_cfg.column_map, 'text', | |||
'text') | |||
preprocessor[ConfigKeys.train].set_input_text_key(text_key_name) | |||
preprocessor[ConfigKeys.val].set_input_text_key(text_key_name) | |||
self.global_batch_size = cfg.train.dataloader.batch_size_per_gpu * world_size | |||
super().__init__( | |||
model=model, | |||
cfg_file=cfg_file, | |||
arg_parse_fn=arg_parse_fn, | |||
data_collator=data_collator, | |||
train_dataset=train_dataset, | |||
eval_dataset=eval_dataset, | |||
preprocessor=preprocessor, | |||
optimizers=optimizers, | |||
seed=seed, | |||
**kwargs, | |||
) | |||
def train_step(self, model, inputs): | |||
model.train() | |||
inputs['mode'] = ModeKeys.TRAIN | |||
model_outputs = model.forward( | |||
inputs | |||
) # {OutputKeys.IMG_EMBEDDING: Tensor(batch_size, dim), OutputKeys.TEXT_EMBEDDING: Tensor(batch_size, dim)} | |||
loss = get_loss(model_outputs, self.loss_img, self.loss_txt, | |||
self.loss_cfg) | |||
train_outputs = {'loss': loss} | |||
# add model output info to log | |||
if 'log_vars' not in train_outputs: | |||
default_keys_pattern = ['loss'] | |||
match_keys = set([]) | |||
for key_p in default_keys_pattern: | |||
match_keys.update( | |||
[key for key in train_outputs.keys() if key_p in key]) | |||
log_vars = {} | |||
for key in match_keys: | |||
value = train_outputs.get(key, None) | |||
if value is not None: | |||
if dist.is_available() and dist.is_initialized(): | |||
value = value.data.clone() | |||
dist.all_reduce(value.div_(dist.get_world_size())) | |||
log_vars.update({key: value.item()}) | |||
unwrapped_model = getattr(model, 'module', model) | |||
log_vars[ | |||
'logit_scale'] = unwrapped_model.clip_model.logit_scale.data.clone( | |||
).item() # noqa | |||
log_vars['global_batch_size'] = int(self.global_batch_size) | |||
self.log_buffer.update(log_vars) | |||
else: | |||
self.log_buffer.update(train_outputs['log_vars']) | |||
self.train_outputs = train_outputs |
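Reviewer note: a toy illustration of the exclude()/include() split applied when no optimizer is passed in. The module names below are invented, but mimic CLIP's naming (LayerNorms named ln_*), so gains, biases and logit_scale land in the zero-weight-decay group.

import torch
from torch import nn

def exclude(n):
    return 'bn' in n or 'ln' in n or 'bias' in n or 'logit_scale' in n

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.ln_final = nn.LayerNorm(4)
        self.logit_scale = nn.Parameter(torch.ones([]))

for name, _ in Toy().named_parameters():
    group = 'no weight decay' if exclude(name) else 'weight decay'
    print(f'{name}: {group}')
# proj.weight: weight decay; proj.bias / ln_final.* / logit_scale: no weight decay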
@@ -1,94 +1,125 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Copyright 2022 The OFA-Sys Team. | |||
# All rights reserved. | |||
# This source code is licensed under the Apache 2.0 license | |||
# found in the LICENSE file in the root directory. | |||
import math | |||
import os | |||
import random | |||
from functools import partial | |||
from inspect import unwrap | |||
import json | |||
import torch | |||
import torch.nn.functional as F | |||
from PIL import Image | |||
from torch.utils.data import Dataset | |||
from torchvision import transforms | |||
from modelscope.utils.constant import ModeKeys | |||
train_transform = transforms.Compose([ | |||
transforms.RandomResizedCrop( | |||
224, scale=(0.5, 1.0), interpolation=Image.BICUBIC), | |||
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], | |||
p=0.8), | |||
transforms.RandomGrayscale(p=0.2), | |||
transforms.RandomHorizontalFlip(), | |||
transforms.ToTensor(), | |||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), | |||
(0.26862954, 0.26130258, 0.27577711)) | |||
]) | |||
val_transform = transforms.Compose([ | |||
transforms.Resize((224, 224), interpolation=Image.BICUBIC), | |||
transforms.ToTensor(), | |||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), | |||
(0.26862954, 0.26130258, 0.27577711)) | |||
]) | |||
class ImageWithCaptionDataset(Dataset): | |||
def __init__(self, json_file, img_dir, phase): | |||
self.annotations = json.load(open(json_file)) | |||
self.img_dir = img_dir | |||
if phase == ModeKeys.TRAIN: | |||
self.transform = train_transform | |||
elif phase == ModeKeys.EVAL: | |||
self.transform = val_transform | |||
self.img_name2img_id = {} | |||
for anno_dict in self.annotations: | |||
img_name = anno_dict['image'] | |||
if img_name not in self.img_name2img_id: | |||
self.img_name2img_id[img_name] = len(self.img_name2img_id) | |||
def __len__(self): | |||
return len(self.annotations) | |||
def __getitem__(self, index): | |||
anno_dict = self.annotations[index] | |||
img_path = os.path.join(self.img_dir, anno_dict['image']) | |||
img_pil = Image.open(img_path).convert('RGB') | |||
img_th = self.transform(img_pil) | |||
img_id = self.img_name2img_id[anno_dict['image']] | |||
text_str = random.choice(anno_dict['caption']) | |||
return img_th, text_str, img_id | |||
def get_params_groups(ddp_model, weight_decay): | |||
decay = [] | |||
no_decay = [] | |||
for name, param in ddp_model.named_parameters(): | |||
if not param.requires_grad: | |||
continue | |||
if len(param.shape) == 1 or name.endswith('.bias'): | |||
no_decay.append(param) | |||
else: | |||
decay.append(param) | |||
params_groups = [{ | |||
'params': no_decay, | |||
'weight_decay': 0. | |||
}, { | |||
'params': decay, | |||
'weight_decay': weight_decay | |||
}] | |||
return params_groups | |||
def get_optimizer(ddp_model): | |||
from torch.optim import AdamW | |||
lr_init = 1e-5 | |||
betas = [0.9, 0.999] | |||
weight_decay = 0.02 | |||
params_groups = get_params_groups(ddp_model, weight_decay=weight_decay) | |||
return AdamW( | |||
params_groups, lr=lr_init, betas=betas, weight_decay=weight_decay) | |||
import torch.distributed as dist | |||
from torch.optim.lr_scheduler import LambdaLR | |||
from modelscope.outputs import OutputKeys | |||
def get_optimizer_params(model_name, cfg): | |||
# get default params | |||
# Params from paper (https://arxiv.org/pdf/2103.00020.pdf) | |||
# base model | |||
if model_name in ['damo/multi-modal_clip-vit-base-patch16_zh']: | |||
params = { | |||
'lr': 5.0e-4, | |||
'beta1': 0.9, | |||
'beta2': 0.98, | |||
'eps': 1.0e-6, | |||
'weight_decay': 0.0 | |||
} | |||
# large models | |||
elif model_name in [ | |||
'damo/multi-modal_clip-vit-large-patch14_zh', | |||
'damo/multi-modal_clip-vit-large-patch14_336_zh' | |||
]: | |||
params = { | |||
'lr': 4.0e-4, | |||
'beta1': 0.9, | |||
'beta2': 0.98, | |||
'eps': 1.0e-6, | |||
'weight_decay': 0.0 | |||
} | |||
else: | |||
params = { | |||
'lr': 5.0e-4, | |||
'beta1': 0.9, | |||
'beta2': 0.999, | |||
'eps': 1.0e-8, | |||
'weight_decay': 0.0 | |||
} | |||
# override with config params | |||
for key in ['lr', 'beta1', 'beta2', 'eps', 'weight_decay']: | |||
if hasattr(cfg.train, 'optimizer_hparams'): | |||
params[key] = getattr(cfg.train.optimizer_hparams, key, | |||
params[key]) | |||
return params | |||
def get_loss(model_outputs, loss_img, loss_txt, loss_cfg): | |||
image_features = model_outputs[OutputKeys.IMG_EMBEDDING] | |||
text_features = model_outputs[OutputKeys.TEXT_EMBEDDING] | |||
logit_scale = model_outputs['logit_scale'] | |||
logit_scale = logit_scale.mean() | |||
if loss_cfg.aggregate and int(os.environ.get('WORLD_SIZE', 1)) > 1: | |||
world_size = dist.get_world_size() | |||
rank = dist.get_rank() | |||
# We gather tensors from all gpus to get more negatives to contrast with. | |||
gathered_image_features = [ | |||
torch.zeros_like(image_features) for _ in range(world_size) | |||
] | |||
gathered_text_features = [ | |||
torch.zeros_like(text_features) for _ in range(world_size) | |||
] | |||
dist.all_gather(gathered_image_features, image_features) | |||
dist.all_gather(gathered_text_features, text_features) | |||
all_image_features = torch.cat([image_features] | |||
+ gathered_image_features[:rank] | |||
+ gathered_image_features[rank + 1:]) | |||
all_text_features = torch.cat([text_features] | |||
+ gathered_text_features[:rank] | |||
+ gathered_text_features[rank + 1:]) | |||
# this is needed to send gradients back everywhere. | |||
logits_per_image = logit_scale * all_image_features @ all_text_features.t( | |||
) | |||
logits_per_text = logits_per_image.t() | |||
else: | |||
logits_per_image = logit_scale * image_features @ text_features.t() | |||
logits_per_text = logit_scale * text_features @ image_features.t() | |||
ground_truth = torch.arange(len(logits_per_image)).long() | |||
ground_truth = ground_truth.cuda( | |||
int(os.environ.get('LOCAL_RANK', 0)), non_blocking=True) | |||
total_loss = (loss_img(logits_per_image, ground_truth) | |||
+ loss_txt(logits_per_text, ground_truth)) / 2 | |||
return total_loss | |||
def lr_lambda(num_warmup_steps, num_training_steps, num_cycles, current_step): | |||
if current_step < num_warmup_steps: | |||
return float(current_step) / float(max(1, num_warmup_steps)) | |||
progress = float(current_step - num_warmup_steps) / float( | |||
max(1, num_training_steps - num_warmup_steps)) | |||
return max( | |||
0.0, | |||
0.5 * # noqa | |||
(1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) # noqa | |||
def get_schedule(optimizer, | |||
scheduler, | |||
num_cycles: float = 0.5, | |||
last_epoch: int = -1): | |||
num_warmup_steps = int(scheduler.warmup_proportion | |||
* scheduler.num_train_steps) | |||
num_training_steps = scheduler.num_train_steps | |||
return LambdaLR( | |||
optimizer, | |||
partial(lr_lambda, num_warmup_steps, num_training_steps, num_cycles), | |||
last_epoch) |
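Reviewer note: a quick sanity check of the warmup + cosine multiplier that get_schedule feeds to LambdaLR (the function is re-stated so the snippet runs standalone; 10 warmup steps, 100 total steps, num_cycles=0.5).

import math
from functools import partial

def lr_lambda(num_warmup_steps, num_training_steps, num_cycles, current_step):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(
        max(1, num_training_steps - num_warmup_steps))
    return max(0.0,
               0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

fn = partial(lr_lambda, 10, 100, 0.5)
print([round(fn(step), 3) for step in (0, 5, 10, 55, 100)])
# [0.0, 0.5, 1.0, 0.5, 0.0]: linear warmup, then half-cosine decay to zero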
@@ -24,7 +24,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): | |||
def test_run(self): | |||
pipeline_multi_modal_embedding = pipeline( | |||
Tasks.multi_modal_embedding, model=self.model_id) | |||
text_embedding = pipeline_multi_modal_embedding( | |||
text_embedding = pipeline_multi_modal_embedding.forward( | |||
self.test_input)[OutputKeys.TEXT_EMBEDDING] | |||
print('l1-norm: {}'.format( | |||
torch.norm(text_embedding, p=1, dim=-1).item())) | |||
@@ -36,7 +36,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): | |||
model = Model.from_pretrained(self.model_id) | |||
pipeline_multi_modal_embedding = pipeline( | |||
task=Tasks.multi_modal_embedding, model=model) | |||
text_embedding = pipeline_multi_modal_embedding( | |||
text_embedding = pipeline_multi_modal_embedding.forward( | |||
self.test_input)[OutputKeys.TEXT_EMBEDDING] | |||
print('l1-norm: {}'.format( | |||
torch.norm(text_embedding, p=1, dim=-1).item())) | |||
@@ -47,7 +47,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): | |||
def test_run_with_default_model(self): | |||
pipeline_multi_modal_embedding = pipeline( | |||
task=Tasks.multi_modal_embedding) | |||
text_embedding = pipeline_multi_modal_embedding( | |||
text_embedding = pipeline_multi_modal_embedding.forward( | |||
self.test_input)[OutputKeys.TEXT_EMBEDDING] | |||
print('l1-norm: {}'.format( | |||
torch.norm(text_embedding, p=1, dim=-1).item())) | |||
@@ -0,0 +1,83 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import shutil | |||
import unittest | |||
import json | |||
from modelscope.metainfo import Metrics, Trainers | |||
from modelscope.msdatasets import MsDataset | |||
from modelscope.trainers import build_trainer | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.test_utils import test_level | |||
class TestClipTrainer(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.finetune_cfg = \ | |||
{'framework': 'pytorch', | |||
'task': 'multi-modal-embedding', | |||
'pipeline': {'type': 'multi-modal-embedding'}, | |||
'pretrained_model': {'model_name': 'damo/multi-modal_clip-vit-base-patch16_zh'}, | |||
'dataset': {'column_map': {'img': 'image', 'text': 'query'}}, | |||
'train': {'work_dir': './workspace/ckpts/clip', | |||
# 'launcher': 'pytorch', | |||
'max_epochs': 1, | |||
'use_fp16': True, | |||
'dataloader': {'batch_size_per_gpu': 8, | |||
'workers_per_gpu': 0, | |||
'shuffle': True, | |||
'drop_last': True}, | |||
'lr_scheduler': {'name': 'cosine', | |||
'warmup_proportion': 0.01}, | |||
'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False}, | |||
'optimizer': {'type': 'AdamW'}, | |||
'optimizer_hparams': {'lr': 5e-05, 'weight_decay': 0.01}, | |||
'optimizer_hook': {'type': 'TorchAMPOptimizerHook', | |||
'cumulative_iters': 1, | |||
'loss_keys': 'loss'}, | |||
'loss_cfg': {'aggregate': True}, | |||
'hooks': [{'type': 'BestCkptSaverHook', | |||
'metric_key': 'inbatch_t2i_recall_at_1', | |||
'interval': 100}, | |||
{'type': 'TextLoggerHook', 'interval': 1}, | |||
{'type': 'IterTimerHook'}, | |||
{'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}, | |||
{'type': 'ClipClampLogitScaleHook'}]}, | |||
'evaluation': {'dataloader': {'batch_size_per_gpu': 8, | |||
'workers_per_gpu': 0, | |||
'shuffle': True, | |||
'drop_last': True}, | |||
'metrics': [{'type': 'inbatch_recall'}]}, | |||
'preprocessor': []} | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_trainer_std(self): | |||
WORKSPACE = './workspace/ckpts/clip' | |||
os.makedirs(WORKSPACE, exist_ok=True) | |||
config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION) | |||
with open(config_file, 'w') as writer: | |||
json.dump(self.finetune_cfg, writer) | |||
pretrained_model = 'damo/multi-modal_clip-vit-base-patch16_zh' | |||
args = dict( | |||
model=pretrained_model, | |||
work_dir=WORKSPACE, | |||
train_dataset=MsDataset.load( | |||
'muge', namespace='modelscope', split='train[:200]'), | |||
eval_dataset=MsDataset.load( | |||
'muge', namespace='modelscope', split='validation[:100]'), | |||
metrics=[Metrics.inbatch_recall], | |||
cfg_file=config_file) | |||
trainer = build_trainer( | |||
name=Trainers.clip_multi_modal_embedding, default_args=args) | |||
trainer.train() | |||
self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, | |||
os.listdir(os.path.join(WORKSPACE, 'output'))) | |||
shutil.rmtree(WORKSPACE) | |||
if __name__ == '__main__': | |||
unittest.main() |