post test hang, but all tests have passedmaster
@@ -1,10 +1,10 @@ | |||
# Introduction | |||
ModelScope library is targeted to support training, evaluation and inference for the state of the art models provided by Mind and further support third-party models provided by users outside alibaba. | |||
[ModelScope]( https://www.modelscope.cn) is a “Model-as-a-Service” (MaaS) platform that seeks to bringing together most advanced machine learning models from the AI community, and to streamlining the process of leveraging and applying AI models . The core ModelScope library enables developers to perform model inference, training and evaluation, through rich layers of API designs that facilitate a unified experience across state-of-the-art models from different AI domains. | |||
In order to enable ModelScope users to use the various models provided by ModelScope quickly and conveniently, we provide a set of complete Python library, which includes the implementation of ModelScope official models, inference, finetuning and evaluation support for those models such as preprocessor and evaluation metrics. We also provide easy-to-use APIs and rich usage examples. By calling the library, users can write just a few lines of code to complete tasks such as model inference, training, and evaluation, and can also quickly carry out secondary development on this basis to realize their own innovative ideas. | |||
The Python library offers the layered-APIs necessary for model contributors to integrate models from CV, NLP, Speech, Multi-Modality, as well as Scientific-computation, into the ModelScope ecosystem. Implementations for all these different models are encapsulated within the library in a way that allows easy and unified access. With such integration, model inference, finetuning, and evaluations can be done within only a few lines of codes. In the meantime, flexibilities are provided so that different components in the model applications can be customized as well, where necessary. | |||
At present, the algorithm models provided by library cover four main AI fields of image, natural language processing, speech, and multi-modality, and dozens of application scenarios and tasks. | |||
Apart from harboring implementations of various models, ModelScope library also enables the necessary interactions with the backend services of ModelScope, particularly with the Model-Hub and Dataset-Hub. Such interactions facilitate various entity (models and datasets) management to be performed seamlessly under-the-hood, such as entity lookup, version control, and cache management. | |||
# Installation | |||
@@ -7,9 +7,9 @@ from typing import Any, Dict, Mapping | |||
import torch | |||
from torch import nn | |||
from torch.onnx import export as onnx_export | |||
from torch.onnx.utils import _decide_input_format | |||
from modelscope.models import TorchModel | |||
from modelscope.outputs import ModelOutputBase | |||
from modelscope.pipelines.base import collate_fn | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.logger import get_logger | |||
@@ -102,6 +102,53 @@ class TorchModelExporter(Exporter): | |||
""" | |||
return None | |||
@staticmethod | |||
def _decide_input_format(model, args): | |||
import inspect | |||
def _signature(model) -> inspect.Signature: | |||
should_be_callable = getattr(model, 'forward', model) | |||
if callable(should_be_callable): | |||
return inspect.signature(should_be_callable) | |||
raise ValueError('model has no forward method and is not callable') | |||
try: | |||
sig = _signature(model) | |||
except ValueError as e: | |||
logger.warn('%s, skipping _decide_input_format' % e) | |||
return args | |||
try: | |||
ordered_list_keys = list(sig.parameters.keys()) | |||
if ordered_list_keys[0] == 'self': | |||
ordered_list_keys = ordered_list_keys[1:] | |||
args_dict: Dict = {} | |||
if isinstance(args, list): | |||
args_list = args | |||
elif isinstance(args, tuple): | |||
args_list = list(args) | |||
else: | |||
args_list = [args] | |||
if isinstance(args_list[-1], Mapping): | |||
args_dict = args_list[-1] | |||
args_list = args_list[:-1] | |||
n_nonkeyword = len(args_list) | |||
for optional_arg in ordered_list_keys[n_nonkeyword:]: | |||
if optional_arg in args_dict: | |||
args_list.append(args_dict[optional_arg]) | |||
# Check if this arg has a default value | |||
else: | |||
param = sig.parameters[optional_arg] | |||
if param.default != param.empty: | |||
args_list.append(param.default) | |||
args = args_list if isinstance(args, list) else tuple(args_list) | |||
# Cases of models with no input args | |||
except IndexError: | |||
logger.warn('No input args, skipping _decide_input_format') | |||
except Exception as e: | |||
logger.warn('Skipping _decide_input_format\n {}'.format(e.args[0])) | |||
return args | |||
def _torch_export_onnx(self, | |||
model: nn.Module, | |||
output: str, | |||
@@ -179,16 +226,21 @@ class TorchModelExporter(Exporter): | |||
with torch.no_grad(): | |||
model.eval() | |||
outputs_origin = model.forward( | |||
*_decide_input_format(model, dummy_inputs)) | |||
if isinstance(outputs_origin, Mapping): | |||
outputs_origin = numpify_tensor_nested( | |||
list(outputs_origin.values())) | |||
*self._decide_input_format(model, dummy_inputs)) | |||
if isinstance(outputs_origin, (Mapping, ModelOutputBase)): | |||
outputs_origin = list( | |||
numpify_tensor_nested(outputs_origin).values()) | |||
elif isinstance(outputs_origin, (tuple, list)): | |||
outputs_origin = numpify_tensor_nested(outputs_origin) | |||
outputs_origin = list(numpify_tensor_nested(outputs_origin)) | |||
outputs = ort_session.run( | |||
onnx_outputs, | |||
numpify_tensor_nested(dummy_inputs), | |||
) | |||
outputs = numpify_tensor_nested(outputs) | |||
if isinstance(outputs, dict): | |||
outputs = list(outputs.values()) | |||
elif isinstance(outputs, tuple): | |||
outputs = list(outputs) | |||
tols = {} | |||
if rtol is not None: | |||
@@ -232,12 +284,25 @@ class TorchModelExporter(Exporter): | |||
'Model property dummy_inputs must be set.') | |||
dummy_inputs = collate_fn(dummy_inputs, device) | |||
if isinstance(dummy_inputs, Mapping): | |||
dummy_inputs = tuple(dummy_inputs.values()) | |||
dummy_inputs_filter = [] | |||
for _input in self._decide_input_format(model, dummy_inputs): | |||
if _input is not None: | |||
dummy_inputs_filter.append(_input) | |||
else: | |||
break | |||
if len(dummy_inputs) != len(dummy_inputs_filter): | |||
logger.warn( | |||
f'Dummy inputs is not continuous in the forward method, ' | |||
f'origin length: {len(dummy_inputs)}, ' | |||
f'the length after filtering: {len(dummy_inputs_filter)}') | |||
dummy_inputs = dummy_inputs_filter | |||
with torch.no_grad(): | |||
model.eval() | |||
with replace_call(): | |||
traced_model = torch.jit.trace( | |||
model, dummy_inputs, strict=strict) | |||
model, tuple(dummy_inputs), strict=strict) | |||
torch.jit.save(traced_model, output) | |||
if validation: | |||
@@ -249,6 +314,10 @@ class TorchModelExporter(Exporter): | |||
outputs = numpify_tensor_nested(outputs) | |||
outputs_origin = model.forward(*dummy_inputs) | |||
outputs_origin = numpify_tensor_nested(outputs_origin) | |||
if isinstance(outputs, dict): | |||
outputs = list(outputs.values()) | |||
if isinstance(outputs_origin, dict): | |||
outputs_origin = list(outputs_origin.values()) | |||
tols = {} | |||
if rtol is not None: | |||
tols['rtol'] = rtol | |||
@@ -23,7 +23,8 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||
API_RESPONSE_FIELD_MESSAGE, | |||
API_RESPONSE_FIELD_USERNAME, | |||
DEFAULT_CREDENTIALS_PATH, | |||
MODELSCOPE_ENVIRONMENT, ONE_YEAR_SECONDS, | |||
MODELSCOPE_ENVIRONMENT, | |||
MODELSCOPE_USERNAME, ONE_YEAR_SECONDS, | |||
Licenses, ModelVisibility) | |||
from modelscope.hub.errors import (InvalidParameter, NotExistError, | |||
NotLoginException, NoValidRevisionError, | |||
@@ -38,8 +39,8 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | |||
DEFAULT_MODEL_REVISION, | |||
DEFAULT_REPOSITORY_REVISION, | |||
MASTER_MODEL_BRANCH, DatasetFormations, | |||
DatasetMetaFormats, DownloadMode, | |||
ModelFile) | |||
DatasetMetaFormats, DownloadChannel, | |||
DownloadMode, ModelFile) | |||
from modelscope.utils.logger import get_logger | |||
from .utils.utils import (get_endpoint, get_release_datetime, | |||
model_id_to_group_owner_name) | |||
@@ -382,10 +383,11 @@ class HubApi: | |||
logger.info('Model revision not specified, use default: %s in development mode' % revision) | |||
if revision not in branches and revision not in tags: | |||
raise NotExistError('The model: %s has no branch or tag : %s .' % revision) | |||
logger.info('Development mode use revision: %s' % revision) | |||
else: | |||
revisions = self.list_model_revisions( | |||
model_id, cutoff_timestamp=release_timestamp, use_cookies=False if cookies is None else cookies) | |||
if revision is None: | |||
if revision is None: # user not specified revision, use latest revision before release time | |||
revisions = self.list_model_revisions( | |||
model_id, cutoff_timestamp=release_timestamp, use_cookies=False if cookies is None else cookies) | |||
if len(revisions) == 0: | |||
raise NoValidRevisionError('The model: %s has no valid revision!' % model_id) | |||
# tags (revisions) returned from backend are guaranteed to be ordered by create-time | |||
@@ -393,9 +395,13 @@ class HubApi: | |||
revision = revisions[0] | |||
logger.info('Model revision not specified, use the latest revision: %s' % revision) | |||
else: | |||
# use user-specified revision | |||
revisions = self.list_model_revisions( | |||
model_id, cutoff_timestamp=current_timestamp, use_cookies=False if cookies is None else cookies) | |||
if revision not in revisions: | |||
raise NotExistError( | |||
'The model: %s has no revision: %s !' % (model_id, revision)) | |||
logger.info('Use user-specified model revision: %s' % revision) | |||
return revision | |||
def get_model_branches_and_tags( | |||
@@ -640,6 +646,25 @@ class HubApi: | |||
def check_local_cookies(self, use_cookies) -> CookieJar: | |||
return self._check_cookie(use_cookies=use_cookies) | |||
def dataset_download_uv(self, dataset_name: str, namespace: str): | |||
if not dataset_name or not namespace: | |||
raise ValueError('dataset_name or namespace cannot be empty!') | |||
# get channel and user_name | |||
channel = DownloadChannel.LOCAL.value | |||
user_name = '' | |||
if MODELSCOPE_ENVIRONMENT in os.environ: | |||
channel = os.environ[MODELSCOPE_ENVIRONMENT] | |||
if MODELSCOPE_USERNAME in os.environ: | |||
user_name = os.environ[MODELSCOPE_USERNAME] | |||
url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/uv/{channel}?user={user_name}' | |||
cookies = ModelScopeConfig.get_cookies() | |||
r = requests.post(url, cookies=cookies, headers=self.headers) | |||
resp = r.json() | |||
raise_on_error(resp) | |||
return resp['Message'] | |||
class ModelScopeConfig: | |||
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH) | |||
@@ -755,14 +780,18 @@ class ModelScopeConfig: | |||
env = 'custom' | |||
if MODELSCOPE_ENVIRONMENT in os.environ: | |||
env = os.environ[MODELSCOPE_ENVIRONMENT] | |||
user_name = 'unknown' | |||
if MODELSCOPE_USERNAME in os.environ: | |||
user_name = os.environ[MODELSCOPE_USERNAME] | |||
ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s' % ( | |||
ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s; user/%s' % ( | |||
__version__, | |||
platform.python_version(), | |||
ModelScopeConfig.get_user_session_id(), | |||
platform.platform(), | |||
platform.processor(), | |||
env, | |||
user_name, | |||
) | |||
if isinstance(user_agent, dict): | |||
ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items()) | |||
@@ -18,6 +18,7 @@ API_RESPONSE_FIELD_EMAIL = 'Email' | |||
API_RESPONSE_FIELD_MESSAGE = 'Message' | |||
MODELSCOPE_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT' | |||
MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG' | |||
MODELSCOPE_USERNAME = 'MODELSCOPE_USERNAME' | |||
ONE_YEAR_SECONDS = 24 * 365 * 60 * 60 | |||
@@ -5,6 +5,8 @@ import os | |||
from datetime import datetime | |||
from typing import Optional | |||
import requests | |||
from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, | |||
DEFAULT_MODELSCOPE_GROUP, | |||
MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG, | |||
@@ -85,3 +87,16 @@ def file_integrity_validation(file_path, expected_sha256): | |||
msg = 'File %s integrity check failed, the download may be incomplete, please try again.' % file_path | |||
logger.error(msg) | |||
raise FileIntegrityError(msg) | |||
def create_library_statistics(method: str, name: str, cn_name: Optional[str]): | |||
try: | |||
from modelscope.hub.api import ModelScopeConfig | |||
path = f'{get_endpoint()}/api/v1/statistics/library' | |||
headers = {'user-agent': ModelScopeConfig.get_user_agent()} | |||
params = {'Method': method, 'Name': name, 'CnName': cn_name} | |||
r = requests.post(path, params=params, headers=headers) | |||
r.raise_for_status() | |||
except Exception: | |||
pass | |||
return |
@@ -389,6 +389,7 @@ class Preprocessors(object): | |||
# multi-modal preprocessor | |||
ofa_tasks_preprocessor = 'ofa-tasks-preprocessor' | |||
clip_preprocessor = 'clip-preprocessor' | |||
mplug_tasks_preprocessor = 'mplug-tasks-preprocessor' | |||
# science preprocessor | |||
@@ -428,6 +429,8 @@ class Metrics(object): | |||
image_inpainting_metric = 'image-inpainting-metric' | |||
# metric for ocr | |||
NED = 'ned' | |||
# metric for cross-modal retrieval | |||
inbatch_recall = 'inbatch_recall' | |||
# metric for referring-video-object-segmentation task | |||
referring_video_object_segmentation_metric = 'referring-video-object-segmentation-metric' | |||
@@ -474,6 +477,9 @@ class Hooks(object): | |||
# Compression | |||
SparsityHook = 'SparsityHook' | |||
# CLIP logit_scale clamp | |||
ClipClampLogitScaleHook = 'ClipClampLogitScaleHook' | |||
class LR_Schedulers(object): | |||
"""learning rate scheduler is defined here | |||
@@ -24,6 +24,7 @@ class MetricKeys(object): | |||
ROUGE_1 = 'rouge-1' | |||
ROUGE_L = 'rouge-l' | |||
NED = 'ned' # ocr metric | |||
BatchAcc = 'inbatch_t2i_recall_at_1' | |||
task_default_metrics = { | |||
@@ -0,0 +1,55 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Dict | |||
import numpy as np | |||
import torch | |||
from modelscope.metainfo import Metrics | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.utils.registry import default_group | |||
from .base import Metric | |||
from .builder import METRICS, MetricKeys | |||
@METRICS.register_module( | |||
group_key=default_group, module_name=Metrics.inbatch_recall) | |||
class InbatchRecallMetric(Metric): | |||
"""The metric computation class for in-batch retrieval classes. | |||
This metric class calculates in-batch image recall@1 for each input batch. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
super().__init__(*args, **kwargs) | |||
self.inbatch_t2i_hitcnts = [] | |||
self.batch_sizes = [] | |||
def add(self, outputs: Dict, inputs: Dict): | |||
image_features = outputs[OutputKeys.IMG_EMBEDDING] | |||
text_features = outputs[OutputKeys.TEXT_EMBEDDING] | |||
assert type(image_features) == torch.Tensor and type( | |||
text_features) == torch.Tensor | |||
with torch.no_grad(): | |||
logits_per_image = image_features @ text_features.t() | |||
logits_per_text = logits_per_image.t() | |||
batch_size = logits_per_image.shape[0] | |||
ground_truth = torch.arange(batch_size).long() | |||
ground_truth = ground_truth.to(image_features.device) | |||
inbatch_t2i_hitcnt = (logits_per_text.argmax(-1) == ground_truth | |||
).sum().float().item() | |||
self.inbatch_t2i_hitcnts.append(inbatch_t2i_hitcnt) | |||
self.batch_sizes.append(batch_size) | |||
def evaluate(self): | |||
assert len(self.inbatch_t2i_hitcnts) == len( | |||
self.batch_sizes) and len(self.batch_sizes) > 0 | |||
return { | |||
MetricKeys.BatchAcc: | |||
sum(self.inbatch_t2i_hitcnts) / sum(self.batch_sizes) | |||
} |
@@ -131,6 +131,8 @@ class Model(ABC): | |||
if not hasattr(model, 'cfg'): | |||
model.cfg = cfg | |||
model.name = model_name_or_path | |||
return model | |||
def save_pretrained(self, | |||
@@ -161,5 +163,12 @@ class Model(ABC): | |||
assert config is not None, 'Cannot save the model because the model config is empty.' | |||
if isinstance(config, Config): | |||
config = config.to_dict() | |||
if 'preprocessor' in config and config['preprocessor'] is not None: | |||
if 'mode' in config['preprocessor']: | |||
config['preprocessor']['mode'] = 'inference' | |||
elif 'val' in config['preprocessor'] and 'mode' in config[ | |||
'preprocessor']['val']: | |||
config['preprocessor']['val']['mode'] = 'inference' | |||
save_pretrained(self, target_folder, save_checkpoint_names, | |||
save_function, config, **kwargs) |
@@ -15,15 +15,13 @@ | |||
import os | |||
from collections import OrderedDict | |||
from typing import Any, Dict, Iterable, List, Tuple, Union | |||
from typing import Any, Dict, Tuple, Union | |||
import json | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from PIL import Image | |||
from torchvision.transforms import Compose, Normalize, Resize, ToTensor | |||
from modelscope.metainfo import Models | |||
from modelscope.models import TorchModel | |||
@@ -351,11 +349,13 @@ class CLIP(nn.Module): | |||
text_num_hidden_layers: int, | |||
text_type_vocab_size: int, | |||
tokenizer: FullTokenizer, | |||
# vision_head_width, added this param for ViT-H | |||
vision_head_width: int = 64, | |||
): | |||
super().__init__() | |||
if isinstance(vision_layers, (tuple, list)): | |||
vision_heads = vision_width * 32 // 64 | |||
vision_heads = vision_width * 32 // vision_head_width | |||
self.visual = ModifiedResNet( | |||
layers=vision_layers, | |||
output_dim=embed_dim, | |||
@@ -363,7 +363,7 @@ class CLIP(nn.Module): | |||
input_resolution=image_resolution, | |||
width=vision_width) | |||
else: | |||
vision_heads = vision_width // 64 | |||
vision_heads = vision_width // vision_head_width | |||
self.visual = VisualTransformer( | |||
input_resolution=image_resolution, | |||
patch_size=vision_patch_size, | |||
@@ -506,21 +506,6 @@ def convert_weights(model: nn.Module): | |||
model.apply(_convert_weights_to_fp16) | |||
def _convert_to_rgb(image): | |||
return image.convert('RGB') | |||
def image_transform(image_size=224): | |||
transform = Compose([ | |||
_convert_to_rgb, | |||
Resize((image_size, image_size)), | |||
ToTensor(), | |||
Normalize((0.48145466, 0.4578275, 0.40821073), | |||
(0.26862954, 0.26130258, 0.27577711)), | |||
]) | |||
return transform | |||
@MODELS.register_module(Tasks.multi_modal_embedding, module_name=Models.clip) | |||
class CLIPForMultiModalEmbedding(TorchModel): | |||
@@ -540,72 +525,40 @@ class CLIPForMultiModalEmbedding(TorchModel): | |||
with open(vision_model_config_file, | |||
'r') as fv, open(text_model_config_file, 'r') as ft: | |||
model_info = json.load(fv) | |||
self.model_info = json.load(fv) | |||
for k, v in json.load(ft).items(): | |||
model_info[k] = v | |||
# image preprocess | |||
self.img_preprocess = image_transform(model_info['image_resolution']) | |||
self.model_info[k] = v | |||
# text tokenizer | |||
vocab_file = f'{model_dir}/{ModelFile.VOCAB_FILE}' | |||
self.tokenizer = FullTokenizer(vocab_file=vocab_file) | |||
# initialize the model | |||
self.clip_model = CLIP(**model_info, tokenizer=self.tokenizer) | |||
self.clip_model = CLIP(**self.model_info, tokenizer=self.tokenizer) | |||
convert_weights(self.clip_model) | |||
# restore the pretrained weight | |||
checkpoint = torch.load( | |||
f'{model_dir}/{ModelFile.TORCH_MODEL_BIN_FILE}', 'cpu') | |||
sd = checkpoint['state_dict'] | |||
sd = checkpoint[ | |||
'state_dict'] if 'state_dict' in checkpoint else checkpoint | |||
if next(iter(sd.items()))[0].startswith('module'): | |||
sd = {k[len('module.'):]: v for k, v in sd.items()} | |||
# support the finetuned model | |||
if next(iter(sd.items()))[0].startswith('clip_model'): | |||
sd = {k[len('clip_model.'):]: v for k, v in sd.items()} | |||
self.clip_model.load_state_dict(sd) | |||
self.clip_model.eval() | |||
# place the model | |||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' | |||
if self.device == 'cuda': | |||
self.device = 'cuda:{}'.format(int(os.environ.get( | |||
'LOCAL_RANK', 0))) if torch.cuda.is_available() else 'cpu' | |||
if torch.cuda.is_available(): | |||
self.clip_model.to(self.device) | |||
logger.info('Use GPU for inference') | |||
logger.info('Use GPU {} for finetuning & inference'.format( | |||
int(os.environ.get('LOCAL_RANK', 0)))) | |||
else: | |||
self.clip_model.float() | |||
logger.info('Use CPU for inference') | |||
def tokenize(self, | |||
texts: Union[str, List[str]], | |||
context_length: int = 52) -> torch.LongTensor: | |||
""" | |||
Returns the tokenized representation of given input string(s) | |||
Parameters | |||
---------- | |||
texts : Union[str, List[str]] | |||
An input string or a list of input strings to tokenize | |||
context_length : int | |||
The context length to use; all baseline models use 24 as the context length | |||
Returns | |||
------- | |||
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] | |||
""" | |||
if isinstance(texts, str): | |||
texts = [texts] | |||
all_tokens = [] | |||
for text in texts: | |||
all_tokens.append( | |||
[self.tokenizer.vocab['[CLS]']] | |||
+ self.tokenizer.convert_tokens_to_ids( | |||
self.tokenizer.tokenize(text))[:context_length - 2] | |||
+ [self.tokenizer.vocab['[SEP]']]) | |||
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) | |||
for i, tokens in enumerate(all_tokens): | |||
assert len(tokens) <= context_length | |||
result[i, :len(tokens)] = torch.tensor(tokens) | |||
return result | |||
logger.info('Use CPU for finetuning & inference') | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
from modelscope.outputs import OutputKeys | |||
@@ -613,75 +566,36 @@ class CLIPForMultiModalEmbedding(TorchModel): | |||
OutputKeys.IMG_EMBEDDING: None, | |||
OutputKeys.TEXT_EMBEDDING: None | |||
} | |||
if 'img' in input and input['img'] is not None: | |||
image_input = input['img'] | |||
# single image input | |||
if isinstance(image_input, Image.Image): | |||
image_tensor = self.img_preprocess(image_input).unsqueeze(0) | |||
# multi images input | |||
elif isinstance(image_input, list): | |||
if all([isinstance(elem, Image.Image) | |||
for elem in image_input]): | |||
image_tensor = torch.stack( | |||
[self.img_preprocess(elem) for elem in image_input], | |||
dim=0) | |||
else: | |||
unsupported_elem_type = [ | |||
type(elem) for elem in image_input | |||
if not isinstance(elem, Image.Image) | |||
][0] | |||
raise TypeError( | |||
f'img should be PIL.Image or List[PIL.Image], \ | |||
but got a List containing one {unsupported_elem_type}' | |||
) | |||
# others | |||
else: | |||
raise TypeError( | |||
f'img should be PIL.Image or List[PIL.Image], but got {type(image_input)}' | |||
) | |||
image_tensor = image_tensor.to(self.device) | |||
with torch.no_grad(): | |||
mode = input.get('mode', ModeKeys.INFERENCE) | |||
# encode the image | |||
if 'img' in input and isinstance(input['img'], torch.Tensor): | |||
image_tensor = input['img'].to(self.device) | |||
if image_tensor.dim() == 5 and image_tensor.shape[1] == 1: | |||
image_tensor = image_tensor.squeeze(1) | |||
with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN): | |||
image_features = self.clip_model.encode_image(image_tensor) | |||
image_features /= image_features.norm( | |||
dim=-1, keepdim=True) # l2-normalize | |||
output[OutputKeys.IMG_EMBEDDING] = image_features | |||
if 'text' in input and input['text'] is not None: | |||
text_input = input['text'] | |||
# single text input | |||
if isinstance(text_input, str): | |||
text_tensor = self.tokenize(text_input) | |||
# multi texts input | |||
elif isinstance(text_input, list): | |||
if all([isinstance(elem, str) for elem in text_input]): | |||
text_tensor = self.tokenize(text_input) | |||
else: | |||
unsupported_elem_type = [ | |||
type(elem) for elem in text_input | |||
if not isinstance(elem, str) | |||
][0] | |||
raise TypeError( | |||
f'text should be str or List[str], but got a List containing one {unsupported_elem_type}' | |||
) | |||
# others | |||
else: | |||
raise TypeError( | |||
f'text should be str or List[str], but got {type(text_input)}' | |||
) | |||
text_tensor = text_tensor.to(self.device) | |||
with torch.no_grad(): | |||
if 'text' in input and isinstance(input['text'], torch.Tensor): | |||
text_tensor = input['text'].to(self.device) | |||
if text_tensor.dim() == 3 and text_tensor.shape[1] == 1: | |||
text_tensor = text_tensor.squeeze(1) | |||
with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN): | |||
text_features = self.clip_model.encode_text(text_tensor) | |||
text_features /= text_features.norm( | |||
dim=-1, keepdim=True) # l2-normalize | |||
output[OutputKeys.TEXT_EMBEDDING] = text_features | |||
if mode == ModeKeys.TRAIN: | |||
output['logit_scale'] = (self.clip_model.logit_scale | |||
* 1.0).exp().mean() | |||
return output | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
@@ -136,6 +136,12 @@ class OFAConfig(PretrainedConfig): | |||
entangle_position_embedding=False, | |||
interpolate_position=False, | |||
orig_patch_image_size=224, | |||
share_attn_bias=False, | |||
use_image_feature=True, | |||
disable_entangle=False, | |||
use_ofasys=False, | |||
vit_type='vit_base', | |||
vit_drop_path_rate=0.0, | |||
**kwargs): | |||
self.vocab_size = vocab_size | |||
self.max_position_embeddings = max_position_embeddings | |||
@@ -178,6 +184,13 @@ class OFAConfig(PretrainedConfig): | |||
self.interpolate_position = interpolate_position | |||
self.orig_patch_image_size = orig_patch_image_size | |||
self.share_attn_bias = share_attn_bias | |||
self.use_image_feature = use_image_feature | |||
self.disable_entangle = disable_entangle | |||
self.use_ofasys = use_ofasys | |||
self.vit_type = vit_type | |||
self.vit_drop_path_rate = vit_drop_path_rate | |||
super().__init__( | |||
pad_token_id=pad_token_id, | |||
bos_token_id=bos_token_id, | |||
@@ -35,6 +35,8 @@ from transformers.utils import logging | |||
from .configuration_ofa import OFAConfig | |||
from .generate import utils | |||
from .resnet import ResNet | |||
from .utils.utils import DropPath | |||
from .vit import vit_base, vit_huge, vit_large, vit_large_336 | |||
logger = logging.get_logger(__name__) | |||
@@ -249,45 +251,6 @@ class LayerDropModuleList(nn.ModuleList): | |||
yield m | |||
def drop_path(x, drop_prob: float = 0.0, training: bool = False): | |||
r""" | |||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |||
Args: | |||
x (`nn.Modules`): input nn layers. | |||
drop_prob (`float`): drop path ratio. | |||
training (`bool`): whether is training or inference. | |||
""" | |||
if drop_prob == 0.0 or not training: | |||
return x | |||
keep_prob = 1 - drop_prob | |||
shape = (1, x.shape[1], 1) | |||
random_tensor = keep_prob + torch.rand( | |||
shape, dtype=x.dtype, device=x.device) | |||
random_tensor.floor_() # binarize | |||
output = x.div(keep_prob) * random_tensor | |||
return output | |||
class DropPath(nn.Module): | |||
r""" | |||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |||
Args: | |||
drop_prob: drop path ratio. | |||
""" | |||
def __init__(self, drop_prob=None): | |||
super().__init__() | |||
self.drop_prob = drop_prob | |||
def forward(self, x): | |||
return drop_path(x, self.drop_prob, self.training) | |||
def extra_repr(self) -> str: | |||
return 'p={}'.format(self.drop_prob) | |||
class OFAAttention(nn.Module): | |||
r""" | |||
Multi-headed attention, with additional implementation for NormFormer. | |||
@@ -898,31 +861,49 @@ class OFAEncoder(OFAPreTrainedModel): | |||
self.padding_idx) | |||
if config.add_type_embedding: | |||
self.type_embedding = Embedding(2, embed_dim, padding_idx=None) | |||
if config.use_image_feature: | |||
self.type_embedding = Embedding(2, embed_dim, padding_idx=None) | |||
else: | |||
self.type_embedding = Embedding(1, embed_dim, padding_idx=None) | |||
else: | |||
self.type_embedding = None | |||
if config.resnet_type == 'resnet18': | |||
self.embed_images = ResNet( | |||
[2, 2, 2], drop_path_rate=config.resnet_drop_path_rate) | |||
elif config.resnet_type == 'resnet34': | |||
self.embed_images = ResNet( | |||
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) | |||
elif config.resnet_type == 'resnet50': | |||
self.embed_images = ResNet( | |||
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) | |||
elif config.resnet_type == 'resnet101': | |||
self.embed_images = ResNet( | |||
[3, 4, 23], drop_path_rate=config.resnet_drop_path_rate) | |||
elif config.resnet_type == 'resnet152': | |||
self.embed_images = ResNet( | |||
[3, 8, 36], drop_path_rate=config.resnet_drop_path_rate) | |||
else: | |||
raise NotImplementedError | |||
if config.use_image_feature: | |||
if config.use_ofasys: | |||
vit_backbone = { | |||
'vit_base': vit_base, | |||
'vit_large': vit_large, | |||
'vit_large_336': vit_large_336, | |||
'vit_huge': vit_huge, | |||
}[config.vit_type] | |||
self.embed_images = vit_backbone(config.vit_drop_path_rate) | |||
self.image_proj = Linear(1024, embed_dim) | |||
self.image_proj = Linear(self.embed_images.width, embed_dim) | |||
if config.resnet_model_path: | |||
else: | |||
if config.resnet_type == 'resnet18': | |||
self.embed_images = ResNet( | |||
[2, 2, 2], drop_path_rate=config.resnet_drop_path_rate) | |||
elif config.resnet_type == 'resnet34': | |||
self.embed_images = ResNet( | |||
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) | |||
elif config.resnet_type == 'resnet50': | |||
self.embed_images = ResNet( | |||
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) | |||
elif config.resnet_type == 'resnet101': | |||
self.embed_images = ResNet( | |||
[3, 4, 23], | |||
drop_path_rate=config.resnet_drop_path_rate) | |||
elif config.resnet_type == 'resnet152': | |||
self.embed_images = ResNet( | |||
[3, 8, 36], | |||
drop_path_rate=config.resnet_drop_path_rate) | |||
else: | |||
raise NotImplementedError | |||
self.image_proj = Linear(1024, embed_dim) | |||
if not config.use_ofasys and config.resnet_model_path: | |||
print('load resnet {}'.format(config.resnet_model_path)) | |||
resnet_state_dict = torch.load(config.resnet_model_path) | |||
self.embed_images.load_state_dict(resnet_state_dict) | |||
@@ -933,14 +914,21 @@ class OFAEncoder(OFAPreTrainedModel): | |||
self.embed_positions = Embedding(self.max_source_positions + 2, | |||
embed_dim) | |||
self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1, | |||
embed_dim) | |||
self.pos_ln = LayerNorm(embed_dim) | |||
self.image_pos_ln = LayerNorm(embed_dim) | |||
if config.use_image_feature: | |||
self.embed_image_positions = Embedding( | |||
config.image_bucket_size**2 + 1, embed_dim) | |||
if not config.use_ofasys: | |||
self.pos_ln = LayerNorm(embed_dim) | |||
if config.use_image_feature: | |||
self.image_pos_ln = LayerNorm(embed_dim) | |||
self.pos_scaling = float(embed_dim / self.num_attention_heads | |||
* config.attn_scale_factor)**-0.5 | |||
self.pos_q_linear = nn.Linear(embed_dim, embed_dim) | |||
self.pos_k_linear = nn.Linear(embed_dim, embed_dim) | |||
if not (config.use_ofasys and config.entangle_position_embedding): | |||
self.pos_q_linear = nn.Linear(embed_dim, embed_dim) | |||
self.pos_k_linear = nn.Linear(embed_dim, embed_dim) | |||
if self.encoder_layerdrop > 0.0: | |||
self.layers = LayerDropModuleList(p=self.encoder_layerdrop) | |||
@@ -965,22 +953,28 @@ class OFAEncoder(OFAPreTrainedModel): | |||
self.token_bucket_size = config.token_bucket_size | |||
token_num_rel_dis = 2 * config.token_bucket_size - 1 | |||
token_rp_bucket = make_token_bucket_position(config.token_bucket_size) | |||
self.share_attn_bias = config.share_attn_bias | |||
num_rel_pos_tables = 1 if config.share_attn_bias else config.encoder_layers | |||
self.token_rel_pos_table_list = nn.ModuleList([ | |||
Embedding( | |||
token_num_rel_dis, self.num_attention_heads, zero_init=True) | |||
for _ in range(config.encoder_layers) | |||
for _ in range(num_rel_pos_tables) | |||
]) | |||
self.image_bucket_size = config.image_bucket_size | |||
image_num_rel_dis = (2 * config.image_bucket_size | |||
- 1) * (2 * config.image_bucket_size - 1) + 3 | |||
image_rp_bucket = make_image_bucket_position(config.image_bucket_size, | |||
image_num_rel_dis) | |||
self.image_rel_pos_table_list = nn.ModuleList([ | |||
Embedding( | |||
image_num_rel_dis, self.num_attention_heads, zero_init=True) | |||
for _ in range(config.encoder_layers) | |||
]) | |||
if config.use_image_feature: | |||
self.image_bucket_size = config.image_bucket_size | |||
image_num_rel_dis = (2 * config.image_bucket_size | |||
- 1) * (2 * config.image_bucket_size - 1) + 3 | |||
image_rp_bucket = make_image_bucket_position( | |||
config.image_bucket_size, image_num_rel_dis) | |||
self.image_rel_pos_table_list = nn.ModuleList([ | |||
Embedding( | |||
image_num_rel_dis, | |||
self.num_attention_heads, | |||
zero_init=True) for _ in range(num_rel_pos_tables) | |||
]) | |||
self.register_buffer('image_rp_bucket', image_rp_bucket) | |||
if config.layernorm_embedding: | |||
self.layernorm_embedding = LayerNorm(embed_dim) | |||
@@ -988,12 +982,12 @@ class OFAEncoder(OFAPreTrainedModel): | |||
self.layernorm_embedding = None | |||
self.register_buffer('token_rp_bucket', token_rp_bucket) | |||
self.register_buffer('image_rp_bucket', image_rp_bucket) | |||
self.entangle_position_embedding = config.entangle_position_embedding | |||
self.gradient_checkpointing = False | |||
# Initialize weights and apply final processing | |||
self.post_init() | |||
self.use_ofasys = config.use_ofasys | |||
def get_input_embeddings(self): | |||
r""" | |||
@@ -1305,21 +1299,41 @@ class OFAEncoder(OFAPreTrainedModel): | |||
if has_pads: | |||
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) | |||
pos_embed = self.pos_ln(pos_embed) | |||
if patch_images is not None: | |||
image_pos_embed = self.image_pos_ln(image_pos_embed) | |||
pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) | |||
if patch_images_2 is not None: | |||
image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2) | |||
pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) | |||
if self.use_ofasys: | |||
if patch_images is not None: | |||
pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) | |||
if patch_images_2 is not None: | |||
pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) | |||
else: | |||
pos_embed = self.pos_ln(pos_embed) | |||
if patch_images is not None: | |||
image_pos_embed = self.image_pos_ln(image_pos_embed) | |||
pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) | |||
if patch_images_2 is not None: | |||
image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2) | |||
pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) | |||
def build_abs_pos_bias(pos_embed): | |||
batch_size, seq_length = pos_embed.size(0), pos_embed.size(1) | |||
if not (self.use_ofasys and self.entangle_position_embedding): | |||
pos_q = self.pos_q_linear(pos_embed).view( | |||
batch_size, seq_length, self.num_attention_heads, | |||
-1).transpose(1, 2) * self.pos_scaling | |||
pos_k = self.pos_k_linear(pos_embed).view( | |||
batch_size, seq_length, self.num_attention_heads, | |||
-1).transpose(1, 2) | |||
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) | |||
else: | |||
abs_pos_bias = torch.zeros( | |||
batch_size, | |||
self.num_attention_heads, | |||
seq_length, | |||
seq_length, | |||
dtype=pos_embed.dtype, | |||
device=pos_embed.device) | |||
return abs_pos_bias | |||
pos_q = self.pos_q_linear(pos_embed).view( | |||
x.size(0), x.size(1), self.num_attention_heads, -1).transpose( | |||
1, 2) * self.pos_scaling | |||
pos_k = self.pos_k_linear(pos_embed).view( | |||
x.size(0), x.size(1), self.num_attention_heads, | |||
-1).transpose(1, 2) | |||
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) | |||
abs_pos_bias = build_abs_pos_bias(pos_embed) | |||
# expand attention_mask | |||
if has_pads: | |||
@@ -1334,19 +1348,22 @@ class OFAEncoder(OFAPreTrainedModel): | |||
if output_hidden_states: | |||
encoder_states += (x, ) | |||
self_attn_bias = abs_pos_bias.clone() | |||
real_idx = 0 if self.share_attn_bias else idx | |||
self_attn_bias[:, :, -input_ids.size(1):, | |||
-input_ids.size(1):] += self.get_rel_pos_bias( | |||
input_ids, idx) | |||
input_ids, real_idx) | |||
if patch_images_2 is not None: | |||
self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \ | |||
self.get_image_rel_pos_bias(image_position_ids_2, idx) | |||
self.get_image_rel_pos_bias(image_position_ids_2, real_idx) | |||
self_attn_bias[:, :, | |||
image_num_patches_2:image_num_patches_2 + image_num_patches, # noqa | |||
image_num_patches_2:image_num_patches_2 + image_num_patches] += \ | |||
self.get_image_rel_pos_bias(image_position_ids, idx) # noqa | |||
self.get_image_rel_pos_bias(image_position_ids, real_idx) # noqa | |||
elif patch_images is not None: | |||
self_attn_bias[:, :, :x.size(1) - input_ids.size(1), :x.size(1) - input_ids.size(1)] += \ | |||
self.get_image_rel_pos_bias(image_position_ids, idx) | |||
self.get_image_rel_pos_bias(image_position_ids, real_idx) | |||
self_attn_bias = self_attn_bias.reshape(-1, x.size(1), x.size(1)) | |||
hidden_outputs = layer( | |||
@@ -1398,6 +1415,8 @@ class OFADecoder(OFAPreTrainedModel): | |||
self._future_mask = torch.empty(0) | |||
self.share_input_output_embed = config.share_decoder_input_output_embed | |||
self.num_attention_heads = config.decoder_attention_heads | |||
self.use_ofasys = config.use_ofasys | |||
self.disable_entangle = config.disable_entangle | |||
if embed_tokens is not None: | |||
self.embed_tokens = embed_tokens | |||
@@ -1415,18 +1434,31 @@ class OFADecoder(OFAPreTrainedModel): | |||
else: | |||
self.layernorm_embedding = None | |||
if config.use_ofasys: | |||
if config.add_type_embedding: | |||
self.type_embedding = Embedding( | |||
1, self.embed_dim, padding_idx=None) | |||
else: | |||
self.type_embedding = None | |||
self.window_size = config.code_image_size // 8 | |||
self.embed_positions = Embedding(self.max_target_positions + 2, | |||
self.embed_dim) | |||
self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1, | |||
self.embed_dim) | |||
self.pos_ln = LayerNorm(self.embed_dim) | |||
self.image_pos_ln = LayerNorm(self.embed_dim) | |||
if not config.use_ofasys: | |||
self.embed_image_positions = Embedding( | |||
config.image_bucket_size**2 + 1, self.embed_dim) | |||
if not config.use_ofasys: | |||
self.pos_ln = LayerNorm(self.embed_dim) | |||
self.image_pos_ln = LayerNorm(self.embed_dim) | |||
self.pos_scaling = float(self.embed_dim / self.num_attention_heads | |||
* config.attn_scale_factor)**-0.5 | |||
self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) | |||
self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) | |||
if not (config.use_ofasys and config.entangle_position_embedding): | |||
self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) | |||
self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) | |||
self.cross_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) | |||
self.cross_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) | |||
@@ -1463,33 +1495,41 @@ class OFADecoder(OFAPreTrainedModel): | |||
self.token_bucket_size = config.token_bucket_size | |||
token_num_rel_dis = 2 * config.token_bucket_size - 1 | |||
token_rp_bucket = make_token_bucket_position(config.token_bucket_size) | |||
self.share_attn_bias = config.share_attn_bias | |||
num_rel_pos_tables = 1 if config.share_attn_bias else config.decoder_layers | |||
self.token_rel_pos_table_list = nn.ModuleList([ | |||
Embedding( | |||
token_num_rel_dis, self.num_attention_heads, zero_init=True) | |||
for _ in range(config.decoder_layers) | |||
for _ in range(num_rel_pos_tables) | |||
]) | |||
self.image_bucket_size = config.image_bucket_size | |||
image_num_rel_dis = (2 * config.image_bucket_size | |||
- 1) * (2 * config.image_bucket_size - 1) + 3 | |||
image_rp_bucket = make_image_bucket_position(config.image_bucket_size, | |||
image_num_rel_dis) | |||
image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \ | |||
torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 # noqa | |||
image_position_idx = torch.cat( | |||
[torch.tensor([0]), image_position_idx.view(-1)]) | |||
image_position_idx = torch.cat( | |||
[image_position_idx, | |||
torch.tensor([1024] * 768)]) | |||
self.image_rel_pos_table_list = nn.ModuleList([ | |||
Embedding( | |||
image_num_rel_dis, self.num_attention_heads, zero_init=True) | |||
for _ in range(config.decoder_layers) | |||
]) | |||
if config.use_image_feature: | |||
if not config.use_ofasys: | |||
self.image_bucket_size = config.image_bucket_size | |||
image_num_rel_dis = (2 * config.image_bucket_size - 1) * ( | |||
2 * config.image_bucket_size - 1) + 3 | |||
image_rp_bucket = make_image_bucket_position( | |||
config.image_bucket_size, image_num_rel_dis) | |||
image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \ | |||
torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 # noqa | |||
image_position_idx = torch.cat( | |||
[torch.tensor([0]), | |||
image_position_idx.view(-1)]) | |||
image_position_idx = torch.cat( | |||
[image_position_idx, | |||
torch.tensor([1024] * 768)]) | |||
self.register_buffer('image_position_idx', image_position_idx) | |||
self.image_rel_pos_table_list = nn.ModuleList([ | |||
Embedding( | |||
image_num_rel_dis, | |||
self.num_attention_heads, | |||
zero_init=True) for _ in range(num_rel_pos_tables) | |||
]) | |||
self.register_buffer('image_rp_bucket', image_rp_bucket) | |||
self.register_buffer('token_rp_bucket', token_rp_bucket) | |||
self.register_buffer('image_rp_bucket', image_rp_bucket) | |||
self.register_buffer('image_position_idx', image_position_idx) | |||
self.entangle_position_embedding = config.entangle_position_embedding | |||
self.gradient_checkpointing = False | |||
@@ -1556,26 +1596,46 @@ class OFADecoder(OFAPreTrainedModel): | |||
batch_size = tgt_pos_embed.size(0) | |||
tgt_len = tgt_pos_embed.size(1) | |||
tgt_pos_embed = self.image_pos_ln( | |||
tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed) | |||
if not self.use_ofasys: | |||
tgt_pos_embed = self.image_pos_ln( | |||
tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed) | |||
if src_pos_embed is not None: | |||
src_len = src_pos_embed.size(1) | |||
pos_q = self.cross_pos_q_linear(tgt_pos_embed).view( | |||
batch_size, tgt_len, self.num_attention_heads, -1).transpose( | |||
1, 2) * self.pos_scaling | |||
pos_k = self.cross_pos_k_linear(src_pos_embed).view( | |||
batch_size, src_len, self.num_attention_heads, | |||
-1).transpose(1, 2) | |||
if not (self.entangle_position_embedding and self.use_ofasys): | |||
pos_q = self.cross_pos_q_linear(tgt_pos_embed).view( | |||
batch_size, tgt_len, self.num_attention_heads, | |||
-1).transpose(1, 2) * self.pos_scaling | |||
pos_k = self.cross_pos_k_linear(src_pos_embed).view( | |||
batch_size, src_len, self.num_attention_heads, | |||
-1).transpose(1, 2) | |||
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) | |||
else: | |||
abs_pos_bias = torch.zeros( | |||
batch_size, | |||
self.num_attention_heads, | |||
tgt_len, | |||
src_len, | |||
dtype=tgt_pos_embed.dtype, | |||
device=tgt_pos_embed.device) | |||
else: | |||
src_len = tgt_pos_embed.size(1) | |||
pos_q = self.self_pos_q_linear(tgt_pos_embed).view( | |||
batch_size, tgt_len, self.num_attention_heads, -1).transpose( | |||
1, 2) * self.pos_scaling | |||
pos_k = self.self_pos_k_linear(tgt_pos_embed).view( | |||
batch_size, src_len, self.num_attention_heads, | |||
-1).transpose(1, 2) | |||
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) | |||
# batch_size, seq_length = tgt_pos_embed.size(0), tgt_pos_embed.size(1) | |||
if not (self.entangle_position_embedding and self.use_ofasys): | |||
pos_q = self.self_pos_q_linear(tgt_pos_embed).view( | |||
batch_size, tgt_len, self.num_attention_heads, | |||
-1).transpose(1, 2) * self.pos_scaling | |||
pos_k = self.self_pos_k_linear(tgt_pos_embed).view( | |||
batch_size, tgt_len, self.num_attention_heads, | |||
-1).transpose(1, 2) | |||
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) | |||
else: | |||
abs_pos_bias = torch.zeros( | |||
batch_size, | |||
self.num_attention_heads, | |||
tgt_len, | |||
tgt_len, | |||
dtype=tgt_pos_embed.dtype, | |||
device=tgt_pos_embed.device) | |||
return abs_pos_bias | |||
@@ -1809,17 +1869,18 @@ class OFADecoder(OFAPreTrainedModel): | |||
past_key_values) > 0 else None | |||
self_attn_bias = self_abs_pos_bias.clone() | |||
real_idx = 0 if self.share_attn_bias else idx | |||
if code_masks is None or not code_masks.any(): | |||
self_attn_bias += self.get_rel_pos_bias( | |||
all_prev_output_tokens, idx).unsqueeze(0) | |||
all_prev_output_tokens, real_idx).unsqueeze(0) | |||
elif code_masks is not None and code_masks.all(): | |||
self_attn_bias += self.get_image_rel_pos_bias( | |||
all_prev_output_tokens, idx).unsqueeze(0) | |||
all_prev_output_tokens, real_idx).unsqueeze(0) | |||
else: | |||
self_attn_bias[~code_masks] += self.get_rel_pos_bias( | |||
all_prev_output_tokens, idx).unsqueeze(0) | |||
all_prev_output_tokens, real_idx).unsqueeze(0) | |||
self_attn_bias[code_masks] += self.get_image_rel_pos_bias( | |||
all_prev_output_tokens, idx).unsqueeze(0) | |||
all_prev_output_tokens, real_idx).unsqueeze(0) | |||
self_attn_bias = self_attn_bias.reshape( | |||
-1, | |||
*self_attn_bias.size()[-2:]) | |||
@@ -1892,6 +1953,7 @@ class OFAModel(OFAPreTrainedModel): | |||
self.encoder = OFAEncoder(config, shared) | |||
self.decoder = OFADecoder(config, shared) | |||
self.use_ofasys = config.use_ofasys | |||
# Initialize weights and apply final processing | |||
self.post_init() | |||
@@ -2,6 +2,7 @@ | |||
from typing import Optional | |||
import torch | |||
import torch.nn as nn | |||
def expand_mask(mask: torch.Tensor, | |||
@@ -17,3 +18,42 @@ def expand_mask(mask: torch.Tensor, | |||
src_len).to(dtype) | |||
return expanded_mask.masked_fill(expanded_mask.bool(), | |||
torch.finfo(dtype).min) | |||
def drop_path(x, drop_prob: float = 0.0, training: bool = False): | |||
r""" | |||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |||
Args: | |||
x (`nn.Modules`): input nn layers. | |||
drop_prob (`float`): drop path ratio. | |||
training (`bool`): whether is training or inference. | |||
""" | |||
if drop_prob == 0.0 or not training: | |||
return x | |||
keep_prob = 1 - drop_prob | |||
shape = (1, x.shape[1], 1) | |||
random_tensor = keep_prob + torch.rand( | |||
shape, dtype=x.dtype, device=x.device) | |||
random_tensor.floor_() # binarize | |||
output = x.div(keep_prob) * random_tensor | |||
return output | |||
class DropPath(nn.Module): | |||
r""" | |||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |||
Args: | |||
drop_prob: drop path ratio. | |||
""" | |||
def __init__(self, drop_prob=None): | |||
super().__init__() | |||
self.drop_prob = drop_prob | |||
def forward(self, x): | |||
return drop_path(x, self.drop_prob, self.training) | |||
def extra_repr(self) -> str: | |||
return 'p={}'.format(self.drop_prob) |
@@ -0,0 +1,155 @@ | |||
from collections import OrderedDict | |||
import torch | |||
import torch.nn.functional as F | |||
from fairseq.modules import LayerNorm | |||
from torch import nn | |||
from .utils.utils import DropPath | |||
__all__ = [ | |||
'vit_base', | |||
'vit_large', | |||
'vit_large_336', | |||
'vit_huge', | |||
] | |||
class QuickGELU(nn.Module): | |||
def forward(self, x: torch.Tensor): | |||
return x * torch.sigmoid(1.702 * x) | |||
class ResidualAttentionBlock(nn.Module): | |||
def __init__(self, | |||
d_model: int, | |||
n_head: int, | |||
attn_mask: torch.Tensor = None, | |||
drop_path_rate=0.0): | |||
super().__init__() | |||
self.attn = nn.MultiheadAttention(d_model, n_head) | |||
self.ln_1 = LayerNorm(d_model) | |||
self.mlp = nn.Sequential( | |||
OrderedDict([ | |||
('c_fc', nn.Linear(d_model, d_model * 4)), | |||
('gelu', QuickGELU()), | |||
('c_proj', nn.Linear(d_model * 4, d_model)), | |||
])) | |||
self.ln_2 = LayerNorm(d_model) | |||
self.attn_mask = attn_mask | |||
self.drop_path = DropPath(drop_path_rate) | |||
def attention(self, x: torch.Tensor): | |||
self.attn_mask = ( | |||
self.attn_mask.to(dtype=x.dtype, device=x.device) | |||
if self.attn_mask is not None else None) | |||
return self.attn( | |||
x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] | |||
def forward(self, x: torch.Tensor): | |||
x = x + self.drop_path(self.attention(self.ln_1(x))) | |||
x = x + self.drop_path(self.mlp(self.ln_2(x))) | |||
return x | |||
class Transformer(nn.Module): | |||
def __init__( | |||
self, | |||
width: int, | |||
layers: int, | |||
heads: int, | |||
attn_mask: torch.Tensor = None, | |||
drop_path_rate: float = 0.0, | |||
): | |||
super().__init__() | |||
self.width = width | |||
self.layers = layers | |||
self.resblocks = nn.Sequential(*[ | |||
ResidualAttentionBlock(width, heads, attn_mask, drop_path_rate) | |||
for _ in range(layers) | |||
]) | |||
def forward(self, x: torch.Tensor): | |||
return self.resblocks(x) | |||
class VisionTransformer(nn.Module): | |||
def __init__( | |||
self, | |||
input_resolution: int, | |||
patch_size: int, | |||
width: int, | |||
layers: int, | |||
heads: int, | |||
drop_path_rate: float = 0.0, | |||
): | |||
super().__init__() | |||
self.input_resolution = input_resolution | |||
self.patch_size = patch_size | |||
self.conv1 = nn.Conv2d( | |||
in_channels=3, | |||
out_channels=width, | |||
kernel_size=patch_size, | |||
stride=patch_size, | |||
bias=False, | |||
) | |||
scale = width**-0.5 | |||
self.width = width | |||
self.positional_embedding = nn.Parameter(scale * torch.randn( | |||
(input_resolution // patch_size)**2 + 1, width)) | |||
self.ln_pre = LayerNorm(width) | |||
self.transformer = Transformer( | |||
width, layers, heads, drop_path_rate=drop_path_rate) | |||
def forward(self, x: torch.Tensor): | |||
resolution = x.shape[-2] | |||
height, width = x.shape[-2] // self.patch_size, x.shape[ | |||
-1] // self.patch_size | |||
x = self.conv1(x) # shape = [*, width, grid, grid] | |||
x = x.reshape(x.shape[0], x.shape[1], | |||
-1) # shape = [*, width, grid ** 2] | |||
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] | |||
if resolution != self.input_resolution: | |||
old_pe = self.positional_embedding[1:] | |||
patch_num = self.input_resolution // self.patch_size | |||
old_pe = old_pe.reshape(1, patch_num, patch_num, | |||
-1).permute(0, 3, 1, 2) | |||
new_pe = F.interpolate( | |||
old_pe, size=(height, width), mode='bilinear') | |||
new_pe = new_pe.permute(0, 2, 3, 1).reshape(height * width, -1) | |||
x = x + new_pe.to(x.dtype) | |||
else: | |||
x = x + self.positional_embedding[1:].to(x.dtype) | |||
x = self.ln_pre(x) | |||
x = x.permute(1, 0, 2) # NLD -> LND | |||
x = self.transformer(x) | |||
x = x.permute(1, 0, 2) # LND -> NLD | |||
bz, seq, hidden = x.shape | |||
x = x.transpose(1, 2).reshape(bz, hidden, height, width) | |||
return x | |||
def vit_base(drop_path_rate: float = 0.0): | |||
return VisionTransformer(224, 16, 768, 9, 12, drop_path_rate) | |||
def vit_large(drop_path_rate: float = 0.0): | |||
return VisionTransformer(224, 14, 1024, 18, 16, drop_path_rate) | |||
def vit_large_336(drop_path_rate: float = 0.0): | |||
return VisionTransformer(336, 14, 1024, 18, 16, drop_path_rate) | |||
def vit_huge(drop_path_rate: float = 0.0): | |||
return VisionTransformer(224, 14, 1280, 24, 16, drop_path_rate) |
@@ -1,6 +1,7 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import math | |||
import os | |||
import re | |||
import string | |||
from functools import partial | |||
from os import path as osp | |||
@@ -53,8 +54,11 @@ class OfaForAllTasks(TorchModel): | |||
raise NotImplementedError | |||
# there is some diff between here and our ofa code, | |||
# there will be no need to use param: use_bpe | |||
self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)]) | |||
self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)]) | |||
if not model.use_ofasys: | |||
self.tokenizer.add_tokens( | |||
['<code_{}>'.format(i) for i in range(8192)]) | |||
self.tokenizer.add_tokens( | |||
['<bin_{}>'.format(i) for i in range(1000)]) | |||
self.cfg.update({'num_bins': 1000, 'num_codes': 8192}) | |||
self.batch_size = self.cfg.model.get('batch_size', 1) | |||
self.patch_image_size = self.cfg.model.get('patch_image_size', 480) | |||
@@ -107,6 +111,8 @@ class OfaForAllTasks(TorchModel): | |||
Tasks.text_classification: inference_d[self.gen_type], | |||
Tasks.image_classification: inference_d[self.gen_type], | |||
} | |||
pattern_str = '((?<=[^ a-zA-Z0-9.,:!?]) +| +(?=[^ a-zA-Z0-9.,:!?]))' | |||
self.pattern = re.compile(pattern_str) | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
input = move_to_device(input, self.model.device) | |||
@@ -132,8 +138,18 @@ class OfaForAllTasks(TorchModel): | |||
caption = input[OutputKeys.CAPTION] | |||
result_l = list() | |||
for cap in caption: | |||
result_l.append(cap.translate(self.transtab).strip()) | |||
if self.language == 'en': | |||
result_l.append(cap.translate(self.transtab).strip()) | |||
else: | |||
result_l.append(cap) | |||
input[OutputKeys.CAPTION] = result_l | |||
if self.gen_type == 'generation' and self.language in [ | |||
'zh', 'cn' | |||
] and self.cfg.task != Tasks.visual_grounding: | |||
ret_l = list() | |||
for text in input[OFA_TASK_KEY_MAPPING[self.cfg.task]]: | |||
ret_l.append(self.detokenizer(text)) | |||
input[OFA_TASK_KEY_MAPPING[self.cfg.task]] = ret_l | |||
return input | |||
def _text_gen_inference(self, input): | |||
@@ -311,3 +327,6 @@ class OfaForAllTasks(TorchModel): | |||
save_function=partial(save_function, with_meta=False), | |||
config=config, | |||
**kwargs) | |||
def detokenizer(self, text): | |||
return self.pattern.sub('', text) |
@@ -36,6 +36,7 @@ class BertForTextRanking(BertForSequenceClassification): | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
*args, | |||
**kwargs) -> AttentionTextClassificationModelOutput: | |||
outputs = self.base_model.forward( | |||
input_ids=input_ids, | |||
@@ -109,6 +109,7 @@ class SbertForSequenceClassification(SbertPreTrainedModel): | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
*args, | |||
**kwargs): | |||
r""" | |||
Args: | |||
@@ -1,3 +1,6 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
import copy | |||
import logging | |||
import os | |||
@@ -1,3 +1,6 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
import argparse | |||
import os | |||
from typing import Any | |||
@@ -0,0 +1,3 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
"""Unifold Modules.""" |
@@ -274,6 +274,8 @@ class MsDataset: | |||
try: | |||
api.on_dataset_download( | |||
dataset_name=download_dataset, namespace=namespace) | |||
api.dataset_download_uv( | |||
dataset_name=download_dataset, namespace=namespace) | |||
except Exception as e: | |||
logger.error(e) | |||
@@ -69,11 +69,23 @@ TASK_OUTPUTS = { | |||
# face 2d keypoint result for single sample | |||
# { | |||
# "keypoints": [ | |||
# [x1, y1]*106 | |||
# [[x, y]*106], | |||
# [[x, y]*106], | |||
# [[x, y]*106], | |||
# ], | |||
# "poses": [pitch, roll, yaw] | |||
# "poses": [ | |||
# [pitch, roll, yaw], | |||
# [pitch, roll, yaw], | |||
# [pitch, roll, yaw], | |||
# ], | |||
# "boxes": [ | |||
# [x1, y1, x2, y2], | |||
# [x1, y1, x2, y2], | |||
# [x1, y1, x2, y2], | |||
# ] | |||
# } | |||
Tasks.face_2d_keypoints: [OutputKeys.KEYPOINTS, OutputKeys.POSES], | |||
Tasks.face_2d_keypoints: | |||
[OutputKeys.KEYPOINTS, OutputKeys.POSES, OutputKeys.BOXES], | |||
# face detection result for single sample | |||
# { | |||
@@ -479,17 +491,8 @@ TASK_OUTPUTS = { | |||
# word segmentation result for single sample | |||
# { | |||
# "output": "今天 天气 不错 , 适合 出去 游玩" | |||
# "labels": [ | |||
# {'word': '今天', 'label': 'PROPN'}, | |||
# {'word': '天气', 'label': 'PROPN'}, | |||
# {'word': '不错', 'label': 'VERB'}, | |||
# {'word': ',', 'label': 'NUM'}, | |||
# {'word': '适合', 'label': 'NOUN'}, | |||
# {'word': '出去', 'label': 'PART'}, | |||
# {'word': '游玩', 'label': 'ADV'}, | |||
# ] | |||
# } | |||
Tasks.word_segmentation: [OutputKeys.OUTPUT, OutputKeys.LABELS], | |||
Tasks.word_segmentation: [OutputKeys.OUTPUT], | |||
# TODO @wenmeng.zwm support list of result check | |||
# named entity recognition result for single sample | |||
@@ -699,8 +702,9 @@ TASK_OUTPUTS = { | |||
# "text_embedding": np.array with shape [1, D], | |||
# "caption": "this is an image caption text." | |||
# } | |||
Tasks.generative_multi_modal_embedding: | |||
[OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.CAPTION], | |||
Tasks.generative_multi_modal_embedding: [ | |||
OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.CAPTION | |||
], | |||
# multi-modal similarity result for single sample | |||
# { | |||
@@ -10,6 +10,7 @@ from typing import Any, Dict, Generator, List, Mapping, Union | |||
import numpy as np | |||
from modelscope.hub.utils.utils import create_library_statistics | |||
from modelscope.models.base import Model | |||
from modelscope.msdatasets import MsDataset | |||
from modelscope.outputs import TASK_OUTPUTS | |||
@@ -151,7 +152,9 @@ class Pipeline(ABC): | |||
**kwargs) -> Union[Dict[str, Any], Generator]: | |||
# model provider should leave it as it is | |||
# modelscope library developer will handle this function | |||
for single_model in self.models: | |||
if hasattr(single_model, 'name'): | |||
create_library_statistics('pipeline', single_model.name, None) | |||
# place model to cpu or gpu | |||
if (self.model or (self.has_multiple_models and self.models[0])): | |||
if not self._model_prepare: | |||
@@ -93,9 +93,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
'damo/cv_resnet50_live-category'), | |||
Tasks.video_category: (Pipelines.video_category, | |||
'damo/cv_resnet50_video-category'), | |||
Tasks.multi_modal_embedding: | |||
(Pipelines.multi_modal_embedding, | |||
'damo/multi-modal_clip-vit-large-patch14_zh'), | |||
Tasks.multi_modal_embedding: (Pipelines.multi_modal_embedding, | |||
'damo/multi-modal_clip-vit-base-patch16_zh'), | |||
Tasks.generative_multi_modal_embedding: | |||
(Pipelines.generative_multi_modal_embedding, | |||
'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding' | |||
@@ -1,12 +1,22 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import copy | |||
import math | |||
from typing import Any | |||
import cv2 | |||
import numpy as np | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import LoadImage | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
from .base import EasyCVPipeline | |||
logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.face_2d_keypoints, module_name=Pipelines.face_2d_keypoints) | |||
@@ -29,18 +39,206 @@ class Face2DKeypointsPipeline(EasyCVPipeline): | |||
*args, | |||
**kwargs) | |||
# face detect pipeline | |||
det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' | |||
self.face_detection = pipeline( | |||
Tasks.face_detection, model=det_model_id) | |||
def show_result(self, img, points, scale=2, save_path=None): | |||
return self.predict_op.show_result(img, points, scale, save_path) | |||
def _choose_face(self, det_result, min_face=10): | |||
""" | |||
choose face with maximum area | |||
Args: | |||
det_result: output of face detection pipeline | |||
min_face: minimum size of valid face w/h | |||
""" | |||
bboxes = np.array(det_result[OutputKeys.BOXES]) | |||
landmarks = np.array(det_result[OutputKeys.KEYPOINTS]) | |||
if bboxes.shape[0] == 0: | |||
logger.warn('No face detected!') | |||
return None | |||
# face idx with enough size | |||
face_idx = [] | |||
for i in range(bboxes.shape[0]): | |||
box = bboxes[i] | |||
if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face: | |||
face_idx += [i] | |||
if len(face_idx) == 0: | |||
logger.warn( | |||
f'Face size not enough, less than {min_face}x{min_face}!') | |||
return None | |||
bboxes = bboxes[face_idx] | |||
landmarks = landmarks[face_idx] | |||
return bboxes, landmarks | |||
def expend_box(self, box, w, h, scalex=0.3, scaley=0.5): | |||
x1 = box[0] | |||
y1 = box[1] | |||
wb = box[2] - x1 | |||
hb = box[3] - y1 | |||
deltax = int(wb * scalex) | |||
deltay1 = int(hb * scaley) | |||
deltay2 = int(hb * scalex) | |||
x1 = x1 - deltax | |||
y1 = y1 - deltay1 | |||
if x1 < 0: | |||
deltax = deltax + x1 | |||
x1 = 0 | |||
if y1 < 0: | |||
deltay1 = deltay1 + y1 | |||
y1 = 0 | |||
x2 = x1 + wb + 2 * deltax | |||
y2 = y1 + hb + deltay1 + deltay2 | |||
x2 = np.clip(x2, 0, w - 1) | |||
y2 = np.clip(y2, 0, h - 1) | |||
return [x1, y1, x2, y2] | |||
def rotate_point(self, angle, center, landmark): | |||
rad = angle * np.pi / 180.0 | |||
alpha = np.cos(rad) | |||
beta = np.sin(rad) | |||
M = np.zeros((2, 3), dtype=np.float32) | |||
M[0, 0] = alpha | |||
M[0, 1] = beta | |||
M[0, 2] = (1 - alpha) * center[0] - beta * center[1] | |||
M[1, 0] = -beta | |||
M[1, 1] = alpha | |||
M[1, 2] = beta * center[0] + (1 - alpha) * center[1] | |||
landmark_ = np.asarray([(M[0, 0] * x + M[0, 1] * y + M[0, 2], | |||
M[1, 0] * x + M[1, 1] * y + M[1, 2]) | |||
for (x, y) in landmark]) | |||
return M, landmark_ | |||
def rotate_crop_img(self, img, pts, M): | |||
imgT = cv2.warpAffine(img, M, (int(img.shape[1]), int(img.shape[0]))) | |||
x1 = pts[5][0] | |||
x2 = pts[5][0] | |||
y1 = pts[5][1] | |||
y2 = pts[5][1] | |||
for i in range(0, 9): | |||
x1 = min(x1, pts[i][0]) | |||
x2 = max(x2, pts[i][0]) | |||
y1 = min(y1, pts[i][1]) | |||
y2 = max(y2, pts[i][1]) | |||
height, width, _ = imgT.shape | |||
x1 = min(max(0, int(x1)), width) | |||
y1 = min(max(0, int(y1)), height) | |||
x2 = min(max(0, int(x2)), width) | |||
y2 = min(max(0, int(y2)), height) | |||
sub_imgT = imgT[y1:y2, x1:x2] | |||
return sub_imgT, imgT, [x1, y1, x2, y2] | |||
def crop_img(self, imgT, pts): | |||
enlarge_ratio = 1.1 | |||
x1 = np.min(pts[:, 0]) | |||
x2 = np.max(pts[:, 0]) | |||
y1 = np.min(pts[:, 1]) | |||
y2 = np.max(pts[:, 1]) | |||
w = x2 - x1 + 1 | |||
h = y2 - y1 + 1 | |||
x1 = int(x1 - (enlarge_ratio - 1.0) / 2.0 * w) | |||
y1 = int(y1 - (enlarge_ratio - 1.0) / 2.0 * h) | |||
x1 = max(0, x1) | |||
y1 = max(0, y1) | |||
new_w = int(enlarge_ratio * w) | |||
new_h = int(enlarge_ratio * h) | |||
new_x1 = x1 | |||
new_y1 = y1 | |||
new_x2 = new_x1 + new_w | |||
new_y2 = new_y1 + new_h | |||
height, width, _ = imgT.shape | |||
new_x1 = min(max(0, new_x1), width) | |||
new_y1 = min(max(0, new_y1), height) | |||
new_x2 = max(min(width, new_x2), 0) | |||
new_y2 = max(min(height, new_y2), 0) | |||
sub_imgT = imgT[new_y1:new_y2, new_x1:new_x2] | |||
return sub_imgT, [new_x1, new_y1, new_x2, new_y2] | |||
def __call__(self, inputs) -> Any: | |||
outputs = self.predict_op(inputs) | |||
img = LoadImage.convert_to_ndarray(inputs) | |||
h, w, c = img.shape | |||
img_rgb = copy.deepcopy(img) | |||
img_rgb = img_rgb[:, :, ::-1] | |||
det_result = self.face_detection(img_rgb) | |||
bboxes = np.array(det_result[OutputKeys.BOXES]) | |||
if bboxes.shape[0] == 0: | |||
logger.warn('No face detected!') | |||
results = { | |||
OutputKeys.KEYPOINTS: [], | |||
OutputKeys.POSES: [], | |||
OutputKeys.BOXES: [] | |||
} | |||
return results | |||
boxes, keypoints = self._choose_face(det_result) | |||
output_boxes = [] | |||
output_keypoints = [] | |||
output_poses = [] | |||
for index, box_ori in enumerate(boxes): | |||
box = self.expend_box(box_ori, w, h, scalex=0.1, scaley=0.1) | |||
y0 = int(box[1]) | |||
y1 = int(box[3]) | |||
x0 = int(box[0]) | |||
x1 = int(box[2]) | |||
sub_img = img[y0:y1, x0:x1] | |||
keypoint = keypoints[index] | |||
pts = [[keypoint[0], keypoint[1]], [keypoint[2], keypoint[3]], | |||
[keypoint[4], keypoint[5]], [keypoint[6], keypoint[7]], | |||
[keypoint[8], keypoint[9]], [box[0], box[1]], | |||
[box[2], box[1]], [box[0], box[3]], [box[2], box[3]]] | |||
# radian | |||
angle = math.atan2((pts[1][1] - pts[0][1]), | |||
(pts[1][0] - pts[0][0])) | |||
# angle | |||
theta = angle * (180 / np.pi) | |||
center = [w // 2, h // 2] | |||
cx, cy = center | |||
M, landmark_ = self.rotate_point(theta, (cx, cy), pts) | |||
sub_imgT, imgT, bbox = self.rotate_crop_img(img, landmark_, M) | |||
outputs = self.predict_op([sub_imgT])[0] | |||
tmp_keypoints = outputs['point'] | |||
for idx in range(0, len(tmp_keypoints)): | |||
tmp_keypoints[idx][0] += bbox[0] | |||
tmp_keypoints[idx][1] += bbox[1] | |||
for idx in range(0, 6): | |||
sub_img, bbox = self.crop_img(imgT, tmp_keypoints) | |||
outputs = self.predict_op([sub_img])[0] | |||
tmp_keypoints = outputs['point'] | |||
for idx in range(0, len(tmp_keypoints)): | |||
tmp_keypoints[idx][0] += bbox[0] | |||
tmp_keypoints[idx][1] += bbox[1] | |||
M2, tmp_keypoints = self.rotate_point(-theta, (cx, cy), | |||
tmp_keypoints) | |||
results = [{ | |||
OutputKeys.KEYPOINTS: output['point'], | |||
OutputKeys.POSES: output['pose'] | |||
} for output in outputs] | |||
output_keypoints.append(np.array(tmp_keypoints)) | |||
output_poses.append(np.array(outputs['pose'])) | |||
output_boxes.append(np.array(box_ori)) | |||
if self._is_single_inputs(inputs): | |||
results = results[0] | |||
results = { | |||
OutputKeys.KEYPOINTS: output_keypoints, | |||
OutputKeys.POSES: output_poses, | |||
OutputKeys.BOXES: output_boxes | |||
} | |||
return results |
@@ -1,10 +1,12 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict | |||
from typing import Any, Dict, Optional, Union | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.multi_modal.clip.model import CLIPForMultiModalEmbedding | |||
from modelscope.pipelines.base import Input, Model, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors.multi_modal import CLIPPreprocessor, Preprocessor | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
@@ -17,7 +19,10 @@ logger = get_logger() | |||
Tasks.multi_modal_embedding, module_name=Pipelines.multi_modal_embedding) | |||
class MultiModalEmbeddingPipeline(Pipeline): | |||
def __init__(self, model: str, device: str = 'gpu'): | |||
def __init__(self, | |||
model: Union[Model, str], | |||
preprocessor: Optional[Preprocessor] = None, | |||
**kwargs): | |||
""" | |||
use `model` and `preprocessor` to create a kws pipeline for prediction | |||
Args: | |||
@@ -29,14 +34,17 @@ class MultiModalEmbeddingPipeline(Pipeline): | |||
pipe_model = model | |||
else: | |||
raise NotImplementedError('model must be a single str') | |||
pipe_model.eval() | |||
if preprocessor is None: | |||
if isinstance(pipe_model, CLIPForMultiModalEmbedding): | |||
preprocessor = CLIPPreprocessor(pipe_model.model_dir) | |||
else: | |||
raise NotImplementedError | |||
super().__init__(model=pipe_model) | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
return input | |||
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
return self.model(input) | |||
return self.model(self.preprocess(input)) | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
return inputs |
@@ -109,13 +109,13 @@ class TokenClassificationPipeline(Pipeline): | |||
chunk['span'] = text[chunk['start']:chunk['end']] | |||
chunks.append(chunk) | |||
# for cws output | |||
# for cws outputs | |||
if len(chunks) > 0 and chunks[0]['type'] == 'cws': | |||
spans = [ | |||
chunk['span'] for chunk in chunks if chunk['span'].strip() | |||
] | |||
seg_result = ' '.join(spans) | |||
outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []} | |||
outputs = {OutputKeys.OUTPUT: seg_result} | |||
# for ner outputs | |||
else: | |||
@@ -115,15 +115,15 @@ class WordSegmentationPipeline(Pipeline): | |||
chunk['span'] = text[chunk['start']:chunk['end']] | |||
chunks.append(chunk) | |||
# for cws output | |||
# for cws outputs | |||
if len(chunks) > 0 and chunks[0]['type'] == 'cws': | |||
spans = [ | |||
chunk['span'] for chunk in chunks if chunk['span'].strip() | |||
] | |||
seg_result = ' '.join(spans) | |||
outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []} | |||
outputs = {OutputKeys.OUTPUT: seg_result} | |||
# for ner outpus | |||
# for ner output | |||
else: | |||
outputs = {OutputKeys.OUTPUT: chunks} | |||
return outputs |
@@ -3,8 +3,11 @@ import os.path as osp | |||
from io import BytesIO | |||
from typing import Any, Dict, List, Tuple, Union | |||
import json | |||
import torch | |||
from PIL import Image | |||
from timm.data import create_transform | |||
from torchvision.transforms import Compose, Normalize, Resize, ToTensor | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.metainfo import Preprocessors | |||
@@ -74,7 +77,7 @@ class OfaPreprocessor(Preprocessor): | |||
data[key] = item | |||
return data | |||
def _ofa_input_compatibility_conversion(self, data): | |||
def _ofa_input_compatibility_conversion(self, data): # fake | |||
if 'image' in data and self.cfg.model.get('type', None) == 'ofa': | |||
if isinstance(data['image'], str): | |||
image = load_image(data['image']) | |||
@@ -93,7 +96,6 @@ class OfaPreprocessor(Preprocessor): | |||
data = input | |||
else: | |||
data = self._build_dict(input) | |||
data = self._ofa_input_compatibility_conversion(data) | |||
sample = self.preprocess(data) | |||
str_data = dict() | |||
for k, v in data.items(): | |||
@@ -107,6 +109,180 @@ class OfaPreprocessor(Preprocessor): | |||
eos_idx=self.tokenizer.eos_token_id) | |||
def _convert_to_rgb(image): | |||
return image.convert('RGB') | |||
@PREPROCESSORS.register_module( | |||
Fields.multi_modal, module_name=Preprocessors.clip_preprocessor) | |||
class CLIPPreprocessor(Preprocessor): | |||
def __init__(self, | |||
model_dir: str, | |||
mode=ModeKeys.INFERENCE, | |||
*args, | |||
**kwargs): | |||
"""preprocess the data | |||
Args: | |||
model_dir (str): model path | |||
mode: preprocessor mode (model mode) | |||
""" | |||
super().__init__(*args, **kwargs) | |||
model_dir = model_dir if osp.exists(model_dir) else snapshot_download( | |||
model_dir) | |||
self.mode = mode | |||
# text tokenizer | |||
from modelscope.models.multi_modal.clip.bert_tokenizer import FullTokenizer | |||
if 'tokenizer' in kwargs and isinstance(kwargs['tokenizer'], | |||
FullTokenizer): | |||
self.tokenizer = kwargs['tokenizer'] | |||
else: | |||
vocab_file = f'{model_dir}/{ModelFile.VOCAB_FILE}' | |||
self.tokenizer = FullTokenizer(vocab_file=vocab_file) | |||
# image preprocessor | |||
if 'resolution' in kwargs and isinstance(kwargs['resolution'], int): | |||
self.image_resolution = kwargs['resolution'] | |||
else: | |||
self.image_resolution = json.load( | |||
open('{}/vision_model_config.json'.format( | |||
model_dir)))['image_resolution'] | |||
self.img_preprocess = self._build_image_transform() | |||
# key mapping | |||
# specify the input keys, compatible with training and inference whose key names may be different | |||
self.input_keys = {'img': 'img', 'text': 'text'} | |||
def _build_image_transform(self): | |||
if self.mode == ModeKeys.TRAIN: | |||
transform = create_transform( | |||
input_size=self.image_resolution, | |||
scale=(0.9, 1.0), | |||
is_training=True, | |||
color_jitter=None, | |||
auto_augment='original', | |||
interpolation='bicubic', | |||
mean=(0.48145466, 0.4578275, 0.40821073), | |||
std=(0.26862954, 0.26130258, 0.27577711), | |||
) | |||
transform = Compose(transform.transforms[:-3] + [_convert_to_rgb] | |||
+ transform.transforms[-3:]) | |||
else: | |||
transform = Compose([ | |||
Resize((self.image_resolution, self.image_resolution), | |||
interpolation=Image.BICUBIC), | |||
_convert_to_rgb, | |||
ToTensor(), | |||
Normalize((0.48145466, 0.4578275, 0.40821073), | |||
(0.26862954, 0.26130258, 0.27577711)), | |||
]) | |||
return transform | |||
def tokenize(self, | |||
texts: Union[str, List[str]], | |||
context_length: int = 52) -> torch.LongTensor: | |||
""" | |||
Returns the tokenized representation of given input string(s) | |||
Parameters | |||
---------- | |||
texts : Union[str, List[str]] | |||
An input string or a list of input strings to tokenize | |||
context_length : int | |||
The context length to use; all baseline models use 24 as the context length | |||
Returns | |||
------- | |||
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] | |||
""" | |||
if isinstance(texts, str): | |||
texts = [texts] | |||
all_tokens = [] | |||
for text in texts: | |||
all_tokens.append( | |||
[self.tokenizer.vocab['[CLS]']] | |||
+ self.tokenizer.convert_tokens_to_ids( | |||
self.tokenizer.tokenize(text))[:context_length - 2] | |||
+ [self.tokenizer.vocab['[SEP]']]) | |||
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) | |||
for i, tokens in enumerate(all_tokens): | |||
assert len(tokens) <= context_length | |||
result[i, :len(tokens)] = torch.tensor(tokens) | |||
return result | |||
def set_input_img_key(self, new_key: str): | |||
self.input_keys['img'] = new_key | |||
def set_input_text_key(self, new_key: str): | |||
self.input_keys['text'] = new_key | |||
def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args, | |||
**kwargs) -> Dict[str, Any]: | |||
output = {} | |||
# preprocess the image input | |||
input_img_key = self.input_keys['img'] | |||
if input_img_key in input and input[input_img_key] is not None: | |||
image_input = input[input_img_key] | |||
# single image input | |||
if isinstance(image_input, Image.Image): | |||
image_tensor = self.img_preprocess(image_input).unsqueeze(0) | |||
# multi images input | |||
elif isinstance(image_input, list): | |||
if all([isinstance(elem, Image.Image) | |||
for elem in image_input]): | |||
image_tensor = torch.stack( | |||
[self.img_preprocess(elem) | |||
for elem in image_input], # noqa | |||
dim=0) # noqa | |||
else: | |||
unsupported_elem_type = [ | |||
type(elem) for elem in image_input | |||
if not isinstance(elem, Image.Image) | |||
][0] | |||
raise TypeError( | |||
f'img should be PIL.Image or List[PIL.Image], \ | |||
but got a List containing one {unsupported_elem_type}' | |||
) | |||
# others | |||
else: | |||
raise TypeError( | |||
f'img should be PIL.Image or List[PIL.Image], but got {type(image_input)}' | |||
) | |||
output['img'] = image_tensor | |||
# preprocess the text input | |||
input_text_key = self.input_keys['text'] | |||
if input_text_key in input and input[input_text_key] is not None: | |||
text_input = input[input_text_key] | |||
# single text input | |||
if isinstance(text_input, str): | |||
text_tensor = self.tokenize(text_input) | |||
# multi texts input | |||
elif isinstance(text_input, list): | |||
if all([isinstance(elem, str) for elem in text_input]): | |||
text_tensor = self.tokenize(text_input) | |||
else: | |||
unsupported_elem_type = [ | |||
type(elem) for elem in text_input | |||
if not isinstance(elem, str) | |||
][0] | |||
raise TypeError( | |||
f'text should be str or List[str], but got a List containing one {unsupported_elem_type}' | |||
) | |||
# others | |||
else: | |||
raise TypeError( | |||
f'text should be str or List[str], but got {type(text_input)}' | |||
) | |||
output['text'] = text_tensor | |||
return output | |||
@PREPROCESSORS.register_module( | |||
Fields.multi_modal, module_name=Preprocessors.mplug_tasks_preprocessor) | |||
class MPlugPreprocessor(Preprocessor): | |||
@@ -34,6 +34,7 @@ class NLPBasePreprocessor(Preprocessor, ABC): | |||
label=None, | |||
label2id=None, | |||
mode=ModeKeys.INFERENCE, | |||
use_fast=None, | |||
**kwargs): | |||
"""The NLP preprocessor base class. | |||
@@ -45,14 +46,18 @@ class NLPBasePreprocessor(Preprocessor, ABC): | |||
label2id: An optional label2id mapping, the class will try to call utils.parse_label_mapping | |||
if this mapping is not supplied. | |||
mode: Run this preprocessor in either 'train'/'eval'/'inference' mode | |||
use_fast: use the fast version of tokenizer | |||
""" | |||
self.model_dir = model_dir | |||
self.first_sequence = first_sequence | |||
self.second_sequence = second_sequence | |||
self.label = label | |||
self.use_fast = kwargs.pop('use_fast', None) | |||
if self.use_fast is None and os.path.isfile( | |||
self.use_fast = use_fast | |||
if self.use_fast is None and model_dir is None: | |||
self.use_fast = False | |||
elif self.use_fast is None and os.path.isfile( | |||
os.path.join(model_dir, 'tokenizer_config.json')): | |||
with open(os.path.join(model_dir, 'tokenizer_config.json'), | |||
'r') as f: | |||
@@ -61,8 +66,8 @@ class NLPBasePreprocessor(Preprocessor, ABC): | |||
self.use_fast = False if self.use_fast is None else self.use_fast | |||
self.label2id = label2id | |||
if self.label2id is None: | |||
self.label2id = parse_label_mapping(self.model_dir) | |||
if self.label2id is None and model_dir is not None: | |||
self.label2id = parse_label_mapping(model_dir) | |||
super().__init__(mode, **kwargs) | |||
@property | |||
@@ -106,6 +111,7 @@ class NLPTokenizerPreprocessorBase(NLPBasePreprocessor): | |||
label: str = 'label', | |||
label2id: dict = None, | |||
mode: str = ModeKeys.INFERENCE, | |||
use_fast: bool = None, | |||
**kwargs): | |||
"""The NLP tokenizer preprocessor base class. | |||
@@ -122,11 +128,12 @@ class NLPTokenizerPreprocessorBase(NLPBasePreprocessor): | |||
- config.json label2id/id2label | |||
- label_mapping.json | |||
mode: Run this preprocessor in either 'train'/'eval'/'inference' mode, the behavior may be different. | |||
use_fast: use the fast version of tokenizer | |||
kwargs: These kwargs will be directly fed into the tokenizer. | |||
""" | |||
super().__init__(model_dir, first_sequence, second_sequence, label, | |||
label2id, mode) | |||
label2id, mode, use_fast, **kwargs) | |||
self.model_dir = model_dir | |||
self.tokenize_kwargs = kwargs | |||
self.tokenizer = self.build_tokenizer(model_dir) | |||
@@ -2,6 +2,7 @@ | |||
from typing import Any, Dict, Tuple, Union | |||
import numpy as np | |||
import torch | |||
from modelscope.metainfo import Preprocessors | |||
@@ -20,9 +21,7 @@ class WordSegmentationBlankSetToLabelPreprocessor(NLPBasePreprocessor): | |||
""" | |||
def __init__(self, **kwargs): | |||
super().__init__(**kwargs) | |||
self.first_sequence: str = kwargs.pop('first_sequence', | |||
'first_sequence') | |||
self.first_sequence: str = kwargs.pop('first_sequence', 'tokens') | |||
self.label = kwargs.pop('label', OutputKeys.LABELS) | |||
def __call__(self, data: str) -> Union[Dict[str, Any], Tuple]: | |||
@@ -80,10 +79,9 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
'is_split_into_words', False) | |||
if 'label2id' in kwargs: | |||
kwargs.pop('label2id') | |||
self.tokenize_kwargs = kwargs | |||
@type_assert(object, str) | |||
def __call__(self, data: str) -> Dict[str, Any]: | |||
@type_assert(object, (str, dict)) | |||
def __call__(self, data: Union[dict, str]) -> Dict[str, Any]: | |||
"""process the raw input data | |||
Args: | |||
@@ -99,18 +97,24 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
text = None | |||
labels_list = None | |||
if isinstance(data, str): | |||
# for inference inputs without label | |||
text = data | |||
self.tokenize_kwargs['add_special_tokens'] = False | |||
elif isinstance(data, dict): | |||
# for finetune inputs with label | |||
text = data.get(self.first_sequence) | |||
labels_list = data.get(self.label) | |||
if isinstance(text, list): | |||
self.tokenize_kwargs['is_split_into_words'] = True | |||
input_ids = [] | |||
label_mask = [] | |||
offset_mapping = [] | |||
if self.is_split_into_words: | |||
for offset, token in enumerate(list(data)): | |||
subtoken_ids = self.tokenizer.encode( | |||
token, add_special_tokens=False) | |||
token_type_ids = [] | |||
if self.is_split_into_words and self._mode == ModeKeys.INFERENCE: | |||
for offset, token in enumerate(list(text)): | |||
subtoken_ids = self.tokenizer.encode(token, | |||
**self.tokenize_kwargs) | |||
if len(subtoken_ids) == 0: | |||
subtoken_ids = [self.tokenizer.unk_token_id] | |||
input_ids.extend(subtoken_ids) | |||
@@ -119,10 +123,9 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
else: | |||
if self.tokenizer.is_fast: | |||
encodings = self.tokenizer( | |||
text, | |||
add_special_tokens=False, | |||
return_offsets_mapping=True, | |||
**self.tokenize_kwargs) | |||
text, return_offsets_mapping=True, **self.tokenize_kwargs) | |||
attention_mask = encodings['attention_mask'] | |||
token_type_ids = encodings['token_type_ids'] | |||
input_ids = encodings['input_ids'] | |||
word_ids = encodings.word_ids() | |||
for i in range(len(word_ids)): | |||
@@ -137,75 +140,85 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
label_mask.append(1) | |||
offset_mapping.append(encodings['offset_mapping'][i]) | |||
else: | |||
encodings = self.tokenizer( | |||
text, add_special_tokens=False, **self.tokenize_kwargs) | |||
encodings = self.tokenizer(text, **self.tokenize_kwargs) | |||
input_ids = encodings['input_ids'] | |||
label_mask, offset_mapping = self.get_label_mask_and_offset_mapping( | |||
text) | |||
if len(input_ids) >= self.sequence_length - 2: | |||
input_ids = input_ids[:self.sequence_length - 2] | |||
label_mask = label_mask[:self.sequence_length - 2] | |||
input_ids = [self.tokenizer.cls_token_id | |||
] + input_ids + [self.tokenizer.sep_token_id] | |||
label_mask = [0] + label_mask + [0] | |||
attention_mask = [1] * len(input_ids) | |||
offset_mapping = offset_mapping[:sum(label_mask)] | |||
if self._mode == ModeKeys.INFERENCE: | |||
if len(input_ids) >= self.sequence_length - 2: | |||
input_ids = input_ids[:self.sequence_length - 2] | |||
label_mask = label_mask[:self.sequence_length - 2] | |||
input_ids = [self.tokenizer.cls_token_id | |||
] + input_ids + [self.tokenizer.sep_token_id] | |||
label_mask = [0] + label_mask + [0] | |||
attention_mask = [1] * len(input_ids) | |||
offset_mapping = offset_mapping[:sum(label_mask)] | |||
if not self.is_transformer_based_model: | |||
input_ids = input_ids[1:-1] | |||
attention_mask = attention_mask[1:-1] | |||
label_mask = label_mask[1:-1] | |||
if not self.is_transformer_based_model: | |||
input_ids = input_ids[1:-1] | |||
attention_mask = attention_mask[1:-1] | |||
label_mask = label_mask[1:-1] | |||
if self._mode == ModeKeys.INFERENCE: | |||
input_ids = torch.tensor(input_ids).unsqueeze(0) | |||
attention_mask = torch.tensor(attention_mask).unsqueeze(0) | |||
label_mask = torch.tensor( | |||
label_mask, dtype=torch.bool).unsqueeze(0) | |||
# the token classification | |||
output = { | |||
'text': text, | |||
'input_ids': input_ids, | |||
'attention_mask': attention_mask, | |||
'label_mask': label_mask, | |||
'offset_mapping': offset_mapping | |||
} | |||
# align the labels with tokenized text | |||
if labels_list is not None: | |||
assert self.label2id is not None | |||
# Map that sends B-Xxx label to its I-Xxx counterpart | |||
b_to_i_label = [] | |||
label_enumerate_values = [ | |||
k for k, v in sorted( | |||
self.label2id.items(), key=lambda item: item[1]) | |||
] | |||
for idx, label in enumerate(label_enumerate_values): | |||
if label.startswith('B-') and label.replace( | |||
'B-', 'I-') in label_enumerate_values: | |||
b_to_i_label.append( | |||
label_enumerate_values.index( | |||
label.replace('B-', 'I-'))) | |||
else: | |||
b_to_i_label.append(idx) | |||
# the token classification | |||
output = { | |||
'text': text, | |||
'input_ids': input_ids, | |||
'attention_mask': attention_mask, | |||
'label_mask': label_mask, | |||
'offset_mapping': offset_mapping | |||
} | |||
else: | |||
output = { | |||
'input_ids': input_ids, | |||
'token_type_ids': token_type_ids, | |||
'attention_mask': attention_mask, | |||
'label_mask': label_mask, | |||
} | |||
label_row = [self.label2id[lb] for lb in labels_list] | |||
previous_word_idx = None | |||
label_ids = [] | |||
for word_idx in word_ids: | |||
if word_idx is None: | |||
label_ids.append(-100) | |||
elif word_idx != previous_word_idx: | |||
label_ids.append(label_row[word_idx]) | |||
else: | |||
if self.label_all_tokens: | |||
label_ids.append(b_to_i_label[label_row[word_idx]]) | |||
# align the labels with tokenized text | |||
if labels_list is not None: | |||
assert self.label2id is not None | |||
# Map that sends B-Xxx label to its I-Xxx counterpart | |||
b_to_i_label = [] | |||
label_enumerate_values = [ | |||
k for k, v in sorted( | |||
self.label2id.items(), key=lambda item: item[1]) | |||
] | |||
for idx, label in enumerate(label_enumerate_values): | |||
if label.startswith('B-') and label.replace( | |||
'B-', 'I-') in label_enumerate_values: | |||
b_to_i_label.append( | |||
label_enumerate_values.index( | |||
label.replace('B-', 'I-'))) | |||
else: | |||
b_to_i_label.append(idx) | |||
label_row = [self.label2id[lb] for lb in labels_list] | |||
previous_word_idx = None | |||
label_ids = [] | |||
for word_idx in word_ids: | |||
if word_idx is None: | |||
label_ids.append(-100) | |||
previous_word_idx = word_idx | |||
labels = label_ids | |||
output['labels'] = labels | |||
elif word_idx != previous_word_idx: | |||
label_ids.append(label_row[word_idx]) | |||
else: | |||
if self.label_all_tokens: | |||
label_ids.append(b_to_i_label[label_row[word_idx]]) | |||
else: | |||
label_ids.append(-100) | |||
previous_word_idx = word_idx | |||
labels = label_ids | |||
output['labels'] = labels | |||
output = { | |||
k: np.array(v) if isinstance(v, list) else v | |||
for k, v in output.items() | |||
} | |||
return output | |||
def get_tokenizer_class(self): | |||
@@ -2,12 +2,12 @@ | |||
from typing import Any, Dict | |||
import torch | |||
from PIL import Image | |||
import unicodedata2 | |||
from torchvision import transforms | |||
from torchvision.transforms import InterpolationMode | |||
from torchvision.transforms import functional as F | |||
from zhconv import convert | |||
from modelscope.preprocessors.image import load_image | |||
from modelscope.utils.constant import ModeKeys | |||
from .base import OfaBasePreprocessor | |||
@@ -73,21 +73,14 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): | |||
""" | |||
super(OfaOcrRecognitionPreprocessor, | |||
self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||
# Initialize transform | |||
if self.cfg.model.imagenet_default_mean_and_std: | |||
mean = IMAGENET_DEFAULT_MEAN | |||
std = IMAGENET_DEFAULT_STD | |||
else: | |||
mean = [0.5, 0.5, 0.5] | |||
std = [0.5, 0.5, 0.5] | |||
self.patch_resize_transform = transforms.Compose([ | |||
lambda image: ocr_resize( | |||
image, | |||
self.cfg.model.patch_image_size, | |||
is_document=self.cfg.model.is_document), | |||
self.patch_image_size, | |||
is_document=self.cfg.model.get('is_document', False)), | |||
transforms.ToTensor(), | |||
transforms.Normalize(mean=mean, std=std), | |||
transforms.Normalize(mean=self.mean, std=self.std), | |||
]) | |||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
@@ -98,8 +91,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): | |||
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
sample = self._build_infer_sample(data) | |||
target = data[self.column_map['text']] | |||
target = target.translate(self.transtab).strip() | |||
target = sample['label'] | |||
target_token_list = target.strip().split() | |||
target = ' '.join(target_token_list[:self.max_tgt_length]) | |||
sample['target'] = self.tokenize_text(target, add_bos=False) | |||
@@ -119,5 +111,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): | |||
'patch_mask': torch.tensor([True]) | |||
} | |||
if 'text' in self.column_map and self.column_map['text'] in data: | |||
sample['label'] = data[self.column_map['text']] | |||
target = data[self.column_map['text']] | |||
target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans')) | |||
sample['label'] = target | |||
return sample |
@@ -0,0 +1,18 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import torch | |||
from modelscope.metainfo import Hooks | |||
from modelscope.trainers.multi_modal.clip.clip_trainer import CLIPTrainer | |||
from .builder import HOOKS | |||
from .hook import Hook | |||
@HOOKS.register_module(module_name=Hooks.ClipClampLogitScaleHook) | |||
class ClipClampLogitScaleHook(Hook): | |||
"""ClipClampLogitScaleHook hook which performs clamp on CLIP logit scale parameter after update""" | |||
def after_train_iter(self, trainer: CLIPTrainer): | |||
"""Called after every training iter to evaluate the results.""" | |||
unwrapped_model = getattr(trainer.model, 'module', trainer.model) | |||
logit_scale = unwrapped_model.clip_model.logit_scale | |||
logit_scale.data = torch.clamp(logit_scale.data, 0, 4.6052) |
@@ -61,7 +61,7 @@ class TextLoggerHook(LoggerHook): | |||
self.json_log_path = osp.join(self.out_dir, | |||
'{}.log.json'.format(trainer.timestamp)) | |||
if hasattr(trainer, 'meta') and trainer.meta is not None: | |||
self._dump_log(trainer.meta, trainer) | |||
self._dump_log(trainer.meta) | |||
def _get_max_memory(self, trainer): | |||
device = getattr(trainer.model, 'output_device', None) | |||
@@ -1,169 +1,206 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import math | |||
import os | |||
from typing import Dict, Optional | |||
from typing import Callable, Dict, Optional, Tuple, Union | |||
import torch | |||
import torch.distributed as dist | |||
from torch.utils.data import DataLoader | |||
from torch.utils.data.distributed import DistributedSampler | |||
from torch import distributed as dist | |||
from torch import nn | |||
from torch.utils.data import Dataset | |||
from modelscope.metainfo import Trainers | |||
from modelscope.models.base import Model | |||
from modelscope.trainers.base import BaseTrainer | |||
from modelscope.models.base import Model, TorchModel | |||
from modelscope.models.multi_modal.clip.model import convert_models_to_fp32 | |||
from modelscope.msdatasets.ms_dataset import MsDataset | |||
from modelscope.preprocessors.base import Preprocessor | |||
from modelscope.preprocessors.multi_modal import CLIPPreprocessor | |||
from modelscope.trainers import EpochBasedTrainer | |||
from modelscope.trainers.builder import TRAINERS | |||
from modelscope.trainers.optimizer.builder import build_optimizer | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModeKeys | |||
from modelscope.utils.logger import get_logger | |||
from .clip_trainer_utils import ImageWithCaptionDataset, get_optimizer | |||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys, | |||
ModeKeys) | |||
from .clip_trainer_utils import get_loss, get_optimizer_params, get_schedule | |||
logger = get_logger() | |||
def exclude(n): | |||
return 'bn' in n or 'ln' in n or 'bias' in n or 'logit_scale' in n | |||
def include(n): | |||
return not exclude(n) | |||
@TRAINERS.register_module(module_name=Trainers.clip_multi_modal_embedding) | |||
class CLIPTrainer(BaseTrainer): | |||
def __init__(self, cfg_file: str, model: str, device_id: int, *args, | |||
**kwargs): | |||
super().__init__(cfg_file) | |||
self.cfg = Config.from_file(cfg_file) | |||
self.model = Model.from_pretrained(model) | |||
self.device_id = device_id | |||
self.total_epoch = self.cfg.train.epoch | |||
self.train_batch_size = self.cfg.train.batch_size | |||
self.val_batch_size = self.cfg.evaluation.batch_size | |||
self.ckpt_dir = self.cfg.train.ckpt_dir | |||
self.train_dataset = ImageWithCaptionDataset( | |||
json_file='{}/{}'.format(self.cfg.dataset.root_dir, | |||
self.cfg.dataset.train_set), | |||
img_dir=self.cfg.dataset.root_dir, | |||
phase=ModeKeys.TRAIN) | |||
self.val_dataset = ImageWithCaptionDataset( | |||
json_file='{}/{}'.format(self.cfg.dataset.root_dir, | |||
self.cfg.dataset.val_set), | |||
img_dir=self.cfg.dataset.root_dir, | |||
phase=ModeKeys.EVAL) | |||
def train(self, *args, **kwargs): | |||
assert dist.is_initialized() | |||
self.model.clip_model.train() | |||
self.model.clip_model.to(self.device_id) | |||
ddp_model = torch.nn.parallel.DistributedDataParallel( | |||
self.model.clip_model, device_ids=[ | |||
self.device_id, | |||
]) | |||
optimizer = get_optimizer(ddp_model) | |||
for epoch in range(self.total_epoch): | |||
train_sampler = DistributedSampler( | |||
dataset=self.train_dataset, shuffle=True) | |||
train_sampler.set_epoch(epoch) | |||
train_params = { | |||
'pin_memory': True, | |||
'collate_fn': None, | |||
'batch_size': self.train_batch_size, | |||
'shuffle': False, | |||
'drop_last': True, | |||
'sampler': train_sampler, | |||
'num_workers': 8 | |||
class CLIPTrainer(EpochBasedTrainer): | |||
def __init__( | |||
self, | |||
model: Optional[Union[TorchModel, nn.Module, str]] = None, | |||
cfg_file: Optional[str] = None, | |||
arg_parse_fn: Optional[Callable] = None, | |||
data_collator: Optional[Union[Callable, Dict[str, | |||
Callable]]] = None, | |||
train_dataset: Optional[Union[MsDataset, Dataset]] = None, | |||
eval_dataset: Optional[Union[MsDataset, Dataset]] = None, | |||
preprocessor: Optional[Union[Preprocessor, | |||
Dict[str, Preprocessor]]] = None, | |||
optimizers: Tuple[torch.optim.Optimizer, | |||
torch.optim.lr_scheduler._LRScheduler] = (None, | |||
None), | |||
model_revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
seed: int = 42, | |||
**kwargs): | |||
model = Model.from_pretrained(model, revision=model_revision) | |||
# for training & eval, we convert the model from FP16 back to FP32 | |||
# to compatible with modelscope amp training | |||
convert_models_to_fp32(model) | |||
cfg = Config.from_file(cfg_file) | |||
if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0: | |||
work_dir = cfg.train.work_dir | |||
else: | |||
work_dir = kwargs['work_dir'] | |||
# fetch the model name of CLIP model (base, large or large-336) | |||
model_name = cfg.pretrained_model.model_name | |||
# world size | |||
world_size = int(os.environ.get('WORLD_SIZE', 1)) | |||
# train step, optimizer and lr_scheduler | |||
epoch_steps = math.ceil( | |||
len(train_dataset) / # noqa | |||
(cfg.train.dataloader.batch_size_per_gpu * world_size)) # noqa | |||
cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs | |||
if optimizers[0] is None: | |||
named_parameters = list(model.named_parameters()) | |||
gain_or_bias_params = [ | |||
p for n, p in named_parameters | |||
if exclude(n) and p.requires_grad | |||
] | |||
rest_params = [ | |||
p for n, p in named_parameters | |||
if include(n) and p.requires_grad | |||
] | |||
optimizer_hparams = get_optimizer_params( | |||
model_name, cfg) # lr, wd, beta1, beta2, eps | |||
optimizer_args = { | |||
'params': [ | |||
{ | |||
'params': gain_or_bias_params, | |||
'weight_decay': 0. | |||
}, | |||
{ | |||
'params': rest_params, | |||
'weight_decay': optimizer_hparams['weight_decay'] | |||
}, | |||
], | |||
'lr': | |||
optimizer_hparams['lr'], | |||
'betas': | |||
(optimizer_hparams['beta1'], optimizer_hparams['beta2']), | |||
'eps': | |||
optimizer_hparams['eps'], | |||
} | |||
optimizer = build_optimizer( | |||
model, cfg=cfg.train.optimizer, default_args=optimizer_args) | |||
else: | |||
optimizer = optimizers[0] | |||
if optimizers[1] is None: | |||
lr_scheduler = get_schedule(optimizer, cfg.train.lr_scheduler) | |||
else: | |||
lr_scheduler = optimizers[1] | |||
optimizers = (optimizer, lr_scheduler) | |||
# loss module | |||
loss_img = nn.CrossEntropyLoss() | |||
loss_txt = nn.CrossEntropyLoss() | |||
self.loss_img = loss_img.cuda(int(os.environ.get('LOCAL_RANK', 0))) | |||
self.loss_txt = loss_txt.cuda(int(os.environ.get('LOCAL_RANK', 0))) | |||
self.loss_cfg = cfg.train.loss_cfg | |||
# launcher and use_fp16 | |||
if 'launcher' not in kwargs and cfg.train.get('launcher', None): | |||
kwargs['launcher'] = cfg.train.launcher | |||
if 'use_fp16' not in kwargs and cfg.train.get('use_fp16', False): | |||
kwargs['use_fp16'] = cfg.train.use_fp16 | |||
# preprocessor | |||
if preprocessor is None: | |||
preprocessor = { | |||
ConfigKeys.train: | |||
CLIPPreprocessor( | |||
model_dir=work_dir, | |||
mode=ModeKeys.TRAIN, | |||
tokenizer=model.tokenizer, | |||
resolution=model.model_info['image_resolution']), | |||
ConfigKeys.val: | |||
CLIPPreprocessor( | |||
model_dir=work_dir, | |||
mode=ModeKeys.EVAL, | |||
tokenizer=model.tokenizer, | |||
resolution=model.model_info['image_resolution']), | |||
} | |||
train_loader = DataLoader(self.train_dataset, **train_params) | |||
for batch_idx, (img_tensor, text_str_list, | |||
img_id_list) in enumerate(train_loader): | |||
text_info_list = [ | |||
self.model.tokenize_text(tmp) for tmp in text_str_list | |||
] | |||
text_ids_tensor = torch.cat([tmp[0] for tmp in text_info_list], | |||
dim=0) | |||
text_masks_tensor = torch.cat( | |||
[tmp[1] for tmp in text_info_list], dim=0) | |||
img_tensor = img_tensor.to(self.device_id, non_blocking=True) | |||
img_id_list = img_id_list.to(self.device_id, non_blocking=True) | |||
text_ids_tensor = text_ids_tensor.to( | |||
self.device_id, non_blocking=True) | |||
text_masks_tensor = text_masks_tensor.to( | |||
self.device_id, non_blocking=True) | |||
loss = ddp_model((img_tensor, text_ids_tensor, | |||
text_masks_tensor, img_id_list), | |||
ModeKeys.TRAIN) | |||
optimizer.zero_grad() | |||
loss.backward() | |||
optimizer.step() | |||
if batch_idx % 10 == 0: | |||
logger.info( | |||
'epoch: {}, train batch {}/{}, loss={:.5f}, logit_scale={:.5f}' | |||
.format(epoch, batch_idx, len(train_loader), | |||
loss.item(), | |||
ddp_model.module.logit_scale.exp().item())) | |||
if dist.get_rank() == 0: | |||
os.makedirs(self.ckpt_dir, exist_ok=True) | |||
torch.save(ddp_model.module.state_dict(), | |||
'{}/epoch{}.pth'.format(self.ckpt_dir, epoch)) | |||
def evaluate(self, | |||
checkpoint_path: Optional[str] = None, | |||
*args, | |||
**kwargs) -> Dict[str, float]: | |||
if checkpoint_path is not None: | |||
checkpoint_params = torch.load(checkpoint_path, 'cpu') | |||
self.model.clip_model.load_state_dict(checkpoint_params) | |||
self.model.clip_model.eval() | |||
self.model.clip_model.to(self.device_id) | |||
val_params = { | |||
'collate_fn': None, | |||
'batch_size': self.val_batch_size, | |||
'shuffle': False, | |||
'drop_last': False, | |||
'num_workers': 8 | |||
} | |||
val_loader = DataLoader(self.val_dataset, **val_params) | |||
tp_cnt_per_batch = [] | |||
processed_cnt = 0 | |||
with torch.no_grad(): | |||
for batch_idx, (img_tensor, text_str_list, | |||
img_id_list) in enumerate(val_loader): | |||
text_info_list = [ | |||
self.model.tokenize_text(tmp) for tmp in text_str_list | |||
] | |||
text_ids_tensor = torch.cat([tmp[0] for tmp in text_info_list], | |||
dim=0) | |||
text_masks_tensor = torch.cat( | |||
[tmp[1] for tmp in text_info_list], dim=0) | |||
img_tensor = img_tensor.to(self.device_id, non_blocking=True) | |||
img_id_list = img_id_list.to(self.device_id, non_blocking=True) | |||
text_ids_tensor = text_ids_tensor.to( | |||
self.device_id, non_blocking=True) | |||
text_masks_tensor = text_masks_tensor.to( | |||
self.device_id, non_blocking=True) | |||
img_feat = self.model.clip_model(img_tensor, input_type='img') | |||
text_feat = self.model.clip_model( | |||
(text_ids_tensor, text_masks_tensor), input_type='text') | |||
sim_mat = text_feat @ img_feat.t() | |||
text_cnt, img_cnt = sim_mat.shape | |||
top1_scores, match_ids = torch.max(sim_mat, dim=1) | |||
match_ids = match_ids.int() | |||
gt_ids = torch.tensor(range(0, text_cnt)).to( | |||
self.device_id, non_blocking=True).int() | |||
error_cnt = torch.nonzero(match_ids - gt_ids) | |||
processed_cnt += text_cnt | |||
tp_cnt_per_batch.append(text_cnt - 1.0 * error_cnt.numel()) | |||
logger.info('current acc: {:.3f}'.format( | |||
sum(tp_cnt_per_batch) / processed_cnt)) | |||
# dataset related | |||
self.dataset_cfg = cfg.dataset | |||
if hasattr(self.dataset_cfg, 'column_map'): | |||
# cases where dataset key names are not "img" and "text" | |||
img_key_name = getattr(self.dataset_cfg.column_map, 'img', 'img') | |||
preprocessor[ConfigKeys.train].set_input_img_key(img_key_name) | |||
preprocessor[ConfigKeys.val].set_input_img_key(img_key_name) | |||
text_key_name = getattr(self.dataset_cfg.column_map, 'text', | |||
'text') | |||
preprocessor[ConfigKeys.train].set_input_text_key(text_key_name) | |||
preprocessor[ConfigKeys.val].set_input_text_key(text_key_name) | |||
self.global_batch_size = cfg.train.dataloader.batch_size_per_gpu * world_size | |||
super().__init__( | |||
model=model, | |||
cfg_file=cfg_file, | |||
arg_parse_fn=arg_parse_fn, | |||
data_collator=data_collator, | |||
train_dataset=train_dataset, | |||
eval_dataset=eval_dataset, | |||
preprocessor=preprocessor, | |||
optimizers=optimizers, | |||
seed=seed, | |||
**kwargs, | |||
) | |||
def train_step(self, model, inputs): | |||
model.train() | |||
inputs['mode'] = ModeKeys.TRAIN | |||
model_outputs = model.forward( | |||
inputs | |||
) # {OutputKeys.IMG_EMBEDDING: Tensor(batch_size, dim), OutputKeys.TEXT_EMBEDDING: Tensor(batch_size, dim)} | |||
loss = get_loss(model_outputs, self.loss_img, self.loss_txt, | |||
self.loss_cfg) | |||
train_outputs = {'loss': loss} | |||
# add model output info to log | |||
if 'log_vars' not in train_outputs: | |||
default_keys_pattern = ['loss'] | |||
match_keys = set([]) | |||
for key_p in default_keys_pattern: | |||
match_keys.update( | |||
[key for key in train_outputs.keys() if key_p in key]) | |||
log_vars = {} | |||
for key in match_keys: | |||
value = train_outputs.get(key, None) | |||
if value is not None: | |||
if dist.is_available() and dist.is_initialized(): | |||
value = value.data.clone() | |||
dist.all_reduce(value.div_(dist.get_world_size())) | |||
log_vars.update({key: value.item()}) | |||
unwrapped_model = getattr(model, 'module', model) | |||
log_vars[ | |||
'logit_scale'] = unwrapped_model.clip_model.logit_scale.data.clone( | |||
).item() # noqa | |||
log_vars['global_batch_size'] = int(self.global_batch_size) | |||
self.log_buffer.update(log_vars) | |||
else: | |||
self.log_buffer.update(train_outputs['log_vars']) | |||
self.train_outputs = train_outputs |
@@ -1,94 +1,125 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Copyright 2022 The OFA-Sys Team. | |||
# All rights reserved. | |||
# This source code is licensed under the Apache 2.0 license | |||
# found in the LICENSE file in the root directory. | |||
import math | |||
import os | |||
import random | |||
from functools import partial | |||
from inspect import unwrap | |||
import json | |||
import torch | |||
import torch.nn.functional as F | |||
from PIL import Image | |||
from torch.utils.data import Dataset | |||
from torchvision import transforms | |||
from modelscope.utils.constant import ModeKeys | |||
train_transform = transforms.Compose([ | |||
transforms.RandomResizedCrop( | |||
224, scale=(0.5, 1.0), interpolation=Image.BICUBIC), | |||
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], | |||
p=0.8), | |||
transforms.RandomGrayscale(p=0.2), | |||
transforms.RandomHorizontalFlip(), | |||
transforms.ToTensor(), | |||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), | |||
(0.26862954, 0.26130258, 0.27577711)) | |||
]) | |||
val_transform = transforms.Compose([ | |||
transforms.Resize((224, 224), interpolation=Image.BICUBIC), | |||
transforms.ToTensor(), | |||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), | |||
(0.26862954, 0.26130258, 0.27577711)) | |||
]) | |||
class ImageWithCaptionDataset(Dataset): | |||
def __init__(self, json_file, img_dir, phase): | |||
self.annotations = json.load(open(json_file)) | |||
self.img_dir = img_dir | |||
if phase == ModeKeys.TRAIN: | |||
self.transform = train_transform | |||
elif phase == ModeKeys.EVAL: | |||
self.transform = val_transform | |||
self.img_name2img_id = {} | |||
for anno_dict in self.annotations: | |||
img_name = anno_dict['image'] | |||
if img_name not in self.img_name2img_id: | |||
self.img_name2img_id[img_name] = len(self.img_name2img_id) | |||
def __len__(self): | |||
return len(self.annotations) | |||
def __getitem__(self, index): | |||
anno_dict = self.annotations[index] | |||
img_path = os.path.join(self.img_dir, anno_dict['image']) | |||
img_pil = Image.open(img_path).convert('RGB') | |||
img_th = self.transform(img_pil) | |||
img_id = self.img_name2img_id[anno_dict['image']] | |||
text_str = random.choice(anno_dict['caption']) | |||
return img_th, text_str, img_id | |||
def get_params_groups(ddp_model, weight_decay): | |||
decay = [] | |||
no_decay = [] | |||
for name, param in ddp_model.named_parameters(): | |||
if not param.requires_grad: | |||
continue | |||
if len(param.shape) == 1 or name.endswith('.bias'): | |||
no_decay.append(param) | |||
else: | |||
decay.append(param) | |||
params_groups = [{ | |||
'params': no_decay, | |||
'weight_decay': 0. | |||
}, { | |||
'params': decay, | |||
'weight_decay': weight_decay | |||
}] | |||
return params_groups | |||
def get_optimizer(ddp_model): | |||
from torch.optim import AdamW | |||
lr_init = 1e-5 | |||
betas = [0.9, 0.999] | |||
weight_decay = 0.02 | |||
params_groups = get_params_groups(ddp_model, weight_decay=weight_decay) | |||
return AdamW( | |||
params_groups, lr=lr_init, betas=betas, weight_decay=weight_decay) | |||
import torch.distributed as dist | |||
from torch.optim.lr_scheduler import LambdaLR | |||
from modelscope.outputs import OutputKeys | |||
def get_optimizer_params(model_name, cfg): | |||
# get default params | |||
# Params from paper (https://arxiv.org/pdf/2103.00020.pdf) | |||
# base model | |||
if model_name in ['damo/multi-modal_clip-vit-base-patch16_zh']: | |||
params = { | |||
'lr': 5.0e-4, | |||
'beta1': 0.9, | |||
'beta2': 0.98, | |||
'eps': 1.0e-6, | |||
'weight_decay': 0.0 | |||
} | |||
# large models | |||
elif model_name in [ | |||
'damo/multi-modal_clip-vit-large-patch14_zh', | |||
'damo/multi-modal_clip-vit-large-patch14_336_zh' | |||
]: | |||
params = { | |||
'lr': 4.0e-4, | |||
'beta1': 0.9, | |||
'beta2': 0.98, | |||
'eps': 1.0e-6, | |||
'weight_decay': 0.0 | |||
} | |||
else: | |||
params = { | |||
'lr': 5.0e-4, | |||
'beta1': 0.9, | |||
'beta2': 0.999, | |||
'eps': 1.0e-8, | |||
'weight_decay': 0.0 | |||
} | |||
# override with config params | |||
for key in ['lr', 'beta1', 'beta2', 'eps', 'weight_decay']: | |||
if hasattr(cfg.train, 'optimizer_hparams'): | |||
params[key] = getattr(cfg.train.optimizer_hparams, key, | |||
params[key]) | |||
return params | |||
def get_loss(model_outputs, loss_img, loss_txt, loss_cfg): | |||
image_features = model_outputs[OutputKeys.IMG_EMBEDDING] | |||
text_features = model_outputs[OutputKeys.TEXT_EMBEDDING] | |||
logit_scale = model_outputs['logit_scale'] | |||
logit_scale = logit_scale.mean() | |||
if loss_cfg.aggregate and int(os.environ.get('WORLD_SIZE', 1)) > 1: | |||
world_size = dist.get_world_size() | |||
rank = dist.get_rank() | |||
# We gather tensors from all gpus to get more negatives to contrast with. | |||
gathered_image_features = [ | |||
torch.zeros_like(image_features) for _ in range(world_size) | |||
] | |||
gathered_text_features = [ | |||
torch.zeros_like(text_features) for _ in range(world_size) | |||
] | |||
dist.all_gather(gathered_image_features, image_features) | |||
dist.all_gather(gathered_text_features, text_features) | |||
all_image_features = torch.cat([image_features] | |||
+ gathered_image_features[:rank] | |||
+ gathered_image_features[rank + 1:]) | |||
all_text_features = torch.cat([text_features] | |||
+ gathered_text_features[:rank] | |||
+ gathered_text_features[rank + 1:]) | |||
# this is needed to send gradients back everywhere. | |||
logits_per_image = logit_scale * all_image_features @ all_text_features.t( | |||
) | |||
logits_per_text = logits_per_image.t() | |||
else: | |||
logits_per_image = logit_scale * image_features @ text_features.t() | |||
logits_per_text = logit_scale * text_features @ image_features.t() | |||
ground_truth = torch.arange(len(logits_per_image)).long() | |||
ground_truth = ground_truth.cuda( | |||
int(os.environ.get('LOCAL_RANK', 0)), non_blocking=True) | |||
total_loss = (loss_img(logits_per_image, ground_truth) | |||
+ loss_txt(logits_per_text, ground_truth)) / 2 | |||
return total_loss | |||
def lr_lambda(num_warmup_steps, num_training_steps, num_cycles, current_step): | |||
if current_step < num_warmup_steps: | |||
return float(current_step) / float(max(1, num_warmup_steps)) | |||
progress = float(current_step - num_warmup_steps) / float( | |||
max(1, num_training_steps - num_warmup_steps)) | |||
return max( | |||
0.0, | |||
0.5 * # noqa | |||
(1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) # noqa | |||
def get_schedule(optimizer, | |||
scheduler, | |||
num_cycles: float = 0.5, | |||
last_epoch: int = -1): | |||
num_warmup_steps = int(scheduler.warmup_proportion | |||
* scheduler.num_train_steps) | |||
num_training_steps = scheduler.num_train_steps | |||
return LambdaLR( | |||
optimizer, | |||
partial(lr_lambda, num_warmup_steps, num_training_steps, num_cycles), | |||
last_epoch) |
@@ -103,20 +103,20 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): | |||
def __init__(self, args): | |||
super().__init__() | |||
self.sentence_avg = args.sentence_avg | |||
self.eps = args.label_smoothing | |||
self.ignore_prefix_size = args.ignore_prefix_size | |||
self.ignore_eos = args.ignore_eos | |||
self.report_accuracy = args.report_accuracy | |||
self.drop_worst_ratio = args.drop_worst_ratio | |||
self.drop_worst_after = args.drop_worst_after | |||
self.use_rdrop = args.use_rdrop | |||
self.reg_alpha = args.reg_alpha | |||
self.sample_patch_num = args.sample_patch_num | |||
self.sentence_avg = args.get('sentence_avg', False) | |||
self.eps = args.get('label_smoothing', 0.1) | |||
self.ignore_prefix_size = args.get('ignore_prefix_size', 0) | |||
self.ignore_eos = args.get('ignore_eos', False) | |||
self.report_accuracy = args.get('report_accuracy', False) | |||
self.drop_worst_ratio = args.get('drop_worst_ratio', 0.0) | |||
self.drop_worst_after = args.get('drop_worst_after', 0) | |||
self.use_rdrop = args.get('use_rdrop', False) | |||
self.reg_alpha = args.get('reg_alpha', 1.0) | |||
self.sample_patch_num = args.get('sample_patch_num', 196) | |||
self.constraint_start = None | |||
self.constraint_end = None | |||
if args.constraint_range: | |||
if args.get('constraint_range', None): | |||
constraint_start, constraint_end = args.constraint_range.split(',') | |||
self.constraint_start = int(constraint_start) | |||
self.constraint_end = int(constraint_end) | |||
@@ -18,7 +18,7 @@ class TextGenerationTrainer(NlpEpochBasedTrainer): | |||
return tokenizer.decode(tokens.tolist(), skip_special_tokens=True) | |||
def evaluation_step(self, data): | |||
model = self.model | |||
model = self.model.module if self._dist else self.model | |||
model.eval() | |||
with torch.no_grad(): | |||
@@ -586,14 +586,16 @@ class NlpEpochBasedTrainer(EpochBasedTrainer): | |||
preprocessor_mode=ModeKeys.TRAIN, | |||
**model_args, | |||
**self.train_keys, | |||
mode=ModeKeys.TRAIN) | |||
mode=ModeKeys.TRAIN, | |||
use_fast=True) | |||
eval_preprocessor = Preprocessor.from_pretrained( | |||
self.model_dir, | |||
cfg_dict=self.cfg, | |||
preprocessor_mode=ModeKeys.EVAL, | |||
**model_args, | |||
**self.eval_keys, | |||
mode=ModeKeys.EVAL) | |||
mode=ModeKeys.EVAL, | |||
use_fast=True) | |||
return train_preprocessor, eval_preprocessor | |||
@@ -15,6 +15,7 @@ from torch.utils.data.dataloader import default_collate | |||
from torch.utils.data.distributed import DistributedSampler | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.hub.utils.utils import create_library_statistics | |||
from modelscope.metainfo import Trainers | |||
from modelscope.metrics import build_metric, task_default_metrics | |||
from modelscope.models.base import Model, TorchModel | |||
@@ -183,7 +184,7 @@ class EpochBasedTrainer(BaseTrainer): | |||
preprocessor=self.eval_preprocessor, | |||
**kwargs) | |||
self.train_data_collator, self.eval_default_collate = None, None | |||
self.train_data_collator, self.eval_data_collator = None, None | |||
if isinstance(data_collator, Mapping): | |||
if not (ConfigKeys.train in data_collator | |||
or ConfigKeys.val in data_collator): | |||
@@ -436,6 +437,8 @@ class EpochBasedTrainer(BaseTrainer): | |||
def train(self, checkpoint_path=None, *args, **kwargs): | |||
self._mode = ModeKeys.TRAIN | |||
if hasattr(self.model, 'name'): | |||
create_library_statistics('train', self.model.name, None) | |||
if self.train_dataset is None: | |||
self.train_dataloader = self.get_train_dataloader() | |||
@@ -456,6 +459,8 @@ class EpochBasedTrainer(BaseTrainer): | |||
self.train_loop(self.train_dataloader) | |||
def evaluate(self, checkpoint_path=None): | |||
if hasattr(self.model, 'name'): | |||
create_library_statistics('evaluate', self.model.name, None) | |||
if checkpoint_path is not None and os.path.isfile(checkpoint_path): | |||
from modelscope.trainers.hooks import CheckpointHook | |||
CheckpointHook.load_checkpoint(checkpoint_path, self) | |||
@@ -672,7 +677,7 @@ class EpochBasedTrainer(BaseTrainer): | |||
self.model, cfg=cfg, default_args=default_args) | |||
except KeyError as e: | |||
self.logger.error( | |||
f'Build optimizer error, the optimizer {cfg} is native torch optimizer, ' | |||
f'Build optimizer error, the optimizer {cfg} is a torch native component, ' | |||
f'please check if your torch with version: {torch.__version__} matches the config.' | |||
) | |||
raise e | |||
@@ -682,7 +687,7 @@ class EpochBasedTrainer(BaseTrainer): | |||
return build_lr_scheduler(cfg=cfg, default_args=default_args) | |||
except KeyError as e: | |||
self.logger.error( | |||
f'Build lr_scheduler error, the lr_scheduler {cfg} is native torch lr_scheduler, ' | |||
f'Build lr_scheduler error, the lr_scheduler {cfg} is a torch native component, ' | |||
f'please check if your torch with version: {torch.__version__} matches the config.' | |||
) | |||
raise e | |||
@@ -876,7 +881,7 @@ class EpochBasedTrainer(BaseTrainer): | |||
Subclass and override to inject custom behavior. | |||
""" | |||
model = self.model | |||
model = self.model.module if self._dist else self.model | |||
model.eval() | |||
if is_parallel(model): | |||
@@ -238,6 +238,14 @@ class DownloadMode(enum.Enum): | |||
FORCE_REDOWNLOAD = 'force_redownload' | |||
class DownloadChannel(enum.Enum): | |||
""" Channels of datasets downloading for uv/pv counting. | |||
""" | |||
LOCAL = 'local' | |||
DSW = 'dsw' | |||
EAIS = 'eais' | |||
class UploadMode(enum.Enum): | |||
""" How to upload object to remote. | |||
""" | |||
@@ -91,6 +91,71 @@ def draw_keypoints(output, original_image): | |||
return image | |||
def draw_106face_keypoints(in_path, | |||
keypoints, | |||
boxes, | |||
scale=4.0, | |||
save_path=None): | |||
face_contour_point_index = [ | |||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, | |||
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32 | |||
] | |||
left_eye_brow_point_index = [33, 34, 35, 36, 37, 38, 39, 40, 41, 33] | |||
right_eye_brow_point_index = [42, 43, 44, 45, 46, 47, 48, 49, 50, 42] | |||
left_eye_point_index = [66, 67, 68, 69, 70, 71, 72, 73, 66] | |||
right_eye_point_index = [75, 76, 77, 78, 79, 80, 81, 82, 75] | |||
nose_bridge_point_index = [51, 52, 53, 54] | |||
nose_contour_point_index = [55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65] | |||
mouth_outer_point_index = [ | |||
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 84 | |||
] | |||
mouth_inter_point_index = [96, 97, 98, 99, 100, 101, 102, 103, 96] | |||
img = cv2.imread(in_path) | |||
for i in range(len(boxes)): | |||
draw_box(img, np.array(boxes[i])) | |||
image = cv2.resize(img, dsize=None, fx=scale, fy=scale) | |||
def draw_line(point_index, image, point): | |||
for i in range(len(point_index) - 1): | |||
cur_index = point_index[i] | |||
next_index = point_index[i + 1] | |||
cur_pt = (int(point[cur_index][0] * scale), | |||
int(point[cur_index][1] * scale)) | |||
next_pt = (int(point[next_index][0] * scale), | |||
int(point[next_index][1] * scale)) | |||
cv2.line(image, cur_pt, next_pt, (0, 0, 255), thickness=2) | |||
for i in range(len(keypoints)): | |||
points = keypoints[i] | |||
draw_line(face_contour_point_index, image, points) | |||
draw_line(left_eye_brow_point_index, image, points) | |||
draw_line(right_eye_brow_point_index, image, points) | |||
draw_line(left_eye_point_index, image, points) | |||
draw_line(right_eye_point_index, image, points) | |||
draw_line(nose_bridge_point_index, image, points) | |||
draw_line(nose_contour_point_index, image, points) | |||
draw_line(mouth_outer_point_index, image, points) | |||
draw_line(mouth_inter_point_index, image, points) | |||
size = len(points) | |||
for i in range(size): | |||
x = int(points[i][0]) | |||
y = int(points[i][1]) | |||
cv2.putText(image, str(i), (int(x * scale), int(y * scale)), | |||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) | |||
cv2.circle(image, (int(x * scale), int(y * scale)), 2, (0, 255, 0), | |||
cv2.FILLED) | |||
if save_path is not None: | |||
cv2.imwrite(save_path, image) | |||
return image | |||
def draw_face_detection_no_lm_result(img_path, detection_result): | |||
bboxes = np.array(detection_result[OutputKeys.BOXES]) | |||
scores = np.array(detection_result[OutputKeys.SCORES]) | |||
@@ -4,11 +4,11 @@ import io | |||
import cv2 | |||
import json | |||
import numpy as np | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks, TasksIODescriptions | |||
from modelscope.utils.service_utils import NumpyEncoder | |||
TASKS_INPUT_TEMPLATES = { | |||
# vision tasks | |||
@@ -234,21 +234,6 @@ class DemoCompatibilityCheck(object): | |||
return True | |||
class NumpyEncoder(json.JSONEncoder): | |||
def default(self, obj): | |||
if isinstance(obj, np.ndarray): | |||
return obj.tolist() | |||
if isinstance(obj, np.floating): | |||
return float(obj) | |||
if isinstance(obj, np.integer): | |||
return int(obj) | |||
return json.JSONEncoder.default(self, obj) | |||
def preprocess(req): | |||
in_urls = req.get('urlPaths').get('inUrls') | |||
if len(req['inputs']) == 1: | |||
@@ -19,6 +19,8 @@ import torch | |||
import torch.optim | |||
from torch import nn | |||
from modelscope.utils.service_utils import NumpyEncoder | |||
class RegressTool: | |||
"""This class is used to stop inference/training results from changing by some unaware affections by unittests. | |||
@@ -117,19 +119,6 @@ class RegressTool: | |||
with open(baseline, 'rb') as f: | |||
base = pickle.load(f) | |||
class NumpyEncoder(json.JSONEncoder): | |||
"""Special json encoder for numpy types | |||
""" | |||
def default(self, obj): | |||
if isinstance(obj, np.integer): | |||
return int(obj) | |||
elif isinstance(obj, np.floating): | |||
return float(obj) | |||
elif isinstance(obj, np.ndarray): | |||
return obj.tolist() | |||
return json.JSONEncoder.default(self, obj) | |||
print(f'baseline: {json.dumps(base, cls=NumpyEncoder)}') | |||
print(f'latest : {json.dumps(io_json, cls=NumpyEncoder)}') | |||
if not compare_io_and_print(base, io_json, compare_fn, **kwargs): | |||
@@ -0,0 +1,179 @@ | |||
import base64 | |||
import mimetypes | |||
from io import BytesIO | |||
import json | |||
import numpy as np | |||
import requests | |||
from PIL import Image | |||
from modelscope.outputs import TASK_OUTPUTS, OutputKeys | |||
from modelscope.pipeline_inputs import TASK_INPUTS, InputType | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks, TasksIODescriptions | |||
# service data decoder func decodes data from network and convert it to pipeline's input | |||
# for example | |||
def ExampleDecoder(data): | |||
# Assuming the pipeline inputs is a dict contains an image and a text, | |||
# to decode the data from network we decode the image as base64 | |||
data_json = json.loads(data) | |||
# data: {"image": "xxxxxxxx=="(base64 str), "text": "a question"} | |||
# pipeline(inputs) as follows: | |||
# pipeline({'image': image, 'text': text}) | |||
inputs = { | |||
'image': decode_base64_to_image(data_json.get('image')), | |||
'text': data_json.get('text') | |||
} | |||
return inputs | |||
# service data encoder func encodes data from pipeline outputs and convert to network response (such as json) | |||
# for example | |||
def ExampleEncoder(data): | |||
# Assuming the pipeline outputs is a dict contains an image and a text, | |||
# and transmit it through network, this func encode image to base64 and dumps into json | |||
# data (for e.g. python dict): | |||
# {"image": a numpy array represents a image, "text": "output"} | |||
image = data['image'] | |||
text = data['text'] | |||
data = {'image': encode_array_to_img_base64(image), 'text': text} | |||
return json.dumps(data, cls=NumpyEncoder) | |||
CustomEncoder = { | |||
# Tasks.visual_question_answering: ExampleEncoder | |||
} | |||
CustomDecoder = { | |||
# Tasks.visual_question_answering: ExampleDecoder | |||
} | |||
class NumpyEncoder(json.JSONEncoder): | |||
def default(self, obj): | |||
if isinstance(obj, np.ndarray): | |||
return obj.tolist() | |||
if isinstance(obj, np.floating): | |||
return float(obj) | |||
if isinstance(obj, np.integer): | |||
return int(obj) | |||
return json.JSONEncoder.default(self, obj) | |||
def get_extension(encoding): | |||
encoding = encoding.replace('audio/wav', 'audio/x-wav') | |||
tp = mimetypes.guess_type(encoding)[0] | |||
if tp == 'audio/flac': # flac is not supported by mimetypes | |||
return 'flac' | |||
extension = mimetypes.guess_extension(tp) | |||
if extension is not None and extension.startswith('.'): | |||
extension = extension[1:] | |||
return extension | |||
def get_mimetype(filename): | |||
mimetype = mimetypes.guess_type(filename)[0] | |||
if mimetype is not None: | |||
mimetype = mimetype.replace('x-wav', 'wav').replace('x-flac', 'flac') | |||
return mimetype | |||
def decode_base64_to_binary(encoding): | |||
extension = get_extension(encoding) | |||
data = encoding.split(',')[1] | |||
return base64.b64decode(data), extension | |||
def decode_base64_to_image(encoding): | |||
content = encoding.split(';')[1] | |||
image_encoded = content.split(',')[1] | |||
return Image.open(BytesIO(base64.b64decode(image_encoded))) | |||
def encode_array_to_img_base64(image_array): | |||
with BytesIO() as output_bytes: | |||
pil_image = Image.fromarray(image_array.astype(np.uint8)) | |||
pil_image.save(output_bytes, 'PNG') | |||
bytes_data = output_bytes.getvalue() | |||
base64_str = str(base64.b64encode(bytes_data), 'utf-8') | |||
return 'data:image/png;base64,' + base64_str | |||
def encode_pcm_to_base64(bytes_data): | |||
from scipy.io.wavfile import write | |||
with BytesIO() as out_mem_file: | |||
write(out_mem_file, 16000, bytes_data) | |||
base64_str = str(base64.b64encode(out_mem_file.getvalue()), 'utf-8') | |||
return 'data:audio/pcm;base64,' + base64_str | |||
def encode_url_to_base64(url): | |||
encoded_string = base64.b64encode(requests.get(url).content) | |||
base64_str = str(encoded_string, 'utf-8') | |||
mimetype = get_mimetype(url) | |||
return ('data:' + (mimetype if mimetype is not None else '') + ';base64,' | |||
+ base64_str) | |||
def encode_file_to_base64(f): | |||
with open(f, 'rb') as file: | |||
encoded_string = base64.b64encode(file.read()) | |||
base64_str = str(encoded_string, 'utf-8') | |||
mimetype = get_mimetype(f) | |||
return ('data:' + (mimetype if mimetype is not None else '') | |||
+ ';base64,' + base64_str) | |||
def encode_url_or_file_to_base64(path): | |||
try: | |||
requests.get(path) | |||
return encode_url_to_base64(path) | |||
except (requests.exceptions.MissingSchema, | |||
requests.exceptions.InvalidSchema): | |||
return encode_file_to_base64(path) | |||
def service_data_decoder(task, data): | |||
if CustomDecoder.get(task) is not None: | |||
return CustomDecoder[task](data) | |||
input_type = TASK_INPUTS[task] | |||
input_data = data.decode('utf-8') | |||
if input_type == InputType.IMAGE: | |||
return decode_base64_to_image(input_data) | |||
elif input_type == InputType.AUDIO: | |||
return decode_base64_to_binary(input_data)[0] | |||
elif input_type == InputType.TEXT: | |||
return input_data | |||
elif isinstance(input_type, dict): | |||
input_data = {} | |||
for key, val in input_type.items(): | |||
if val == InputType.IMAGE: | |||
input_data[key] = decode_base64_to_image(data[key]) | |||
elif val == InputType.AUDIO: | |||
input_data[key] = decode_base64_to_binary(data[key])[0] | |||
elif val == InputType.TEXT: | |||
input_data[key] = data[key] | |||
return input_data | |||
def service_data_encoder(task, data): | |||
if CustomEncoder.get(task) is not None: | |||
return CustomEncoder[task](data) | |||
output_keys = TASK_OUTPUTS[task] | |||
result = data | |||
for output_key in output_keys: | |||
if output_key == OutputKeys.OUTPUT_IMG: | |||
result[OutputKeys.OUTPUT_IMG] = encode_array_to_img_base64( | |||
data[OutputKeys.OUTPUT_IMG][..., ::-1]) | |||
elif output_key == OutputKeys.OUTPUT_PCM: | |||
result[OutputKeys.OUTPUT_PCM] = encode_pcm_to_base64( | |||
data[OutputKeys.OUTPUT_PCM]) | |||
result = bytes(json.dumps(result, cls=NumpyEncoder), encoding='utf8') | |||
return result |
@@ -1,5 +1,5 @@ | |||
# Make sure to modify __release_datetime__ to release time when making official release. | |||
__version__ = '0.5.0' | |||
__version__ = '1.0.0' | |||
# default release datetime for branches under active development is set | |||
# to be a time far-far-away-into-the-future | |||
__release_datetime__ = '2099-10-13 08:56:12' |
@@ -1,6 +1,7 @@ | |||
addict | |||
attrs | |||
datasets | |||
# version beyond 2.5.2 introduces compatbility issue and is being resolved | |||
datasets<=2.5.2 | |||
easydict | |||
einops | |||
filelock>=3.3.0 | |||
@@ -2,6 +2,8 @@ ftfy>=6.0.3 | |||
ofa>=0.0.2 | |||
pycocoevalcap>=1.2 | |||
pycocotools>=2.0.4 | |||
# compatible with taming-transformers-rom1504 | |||
pytorch_lightning<=1.7.7 | |||
# rough-score was just recently updated from 0.0.4 to 0.0.7 | |||
# which introduced compatability issues that are being investigated | |||
rouge_score<=0.0.4 | |||
@@ -11,3 +13,5 @@ timm | |||
tokenizers | |||
torchvision | |||
transformers>=4.12.0 | |||
unicodedata2 | |||
zhconv |
@@ -1,6 +1,5 @@ | |||
boto3 | |||
en_core_web_sm>=2.3.5 | |||
fasttext | |||
filelock | |||
ftfy | |||
jieba>=0.42.1 | |||
@@ -1,4 +1,6 @@ | |||
biopython | |||
iopath | |||
ipdb | |||
lmdb | |||
ml_collections | |||
scipy | |||
@@ -23,7 +23,7 @@ class TestExportSbertSequenceClassification(unittest.TestCase): | |||
shutil.rmtree(self.tmp_dir) | |||
super().tearDown() | |||
@unittest.skip | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_export_sbert_sequence_classification(self): | |||
model = Model.from_pretrained(self.model_id) | |||
print( | |||
@@ -115,7 +115,7 @@ class HubRevisionTest(unittest.TestCase): | |||
time.sleep(10) | |||
self.add_new_file_and_tag_to_repo() | |||
t2 = datetime.now().isoformat(sep=' ', timespec='seconds') | |||
logger.info('Secnod time: %s' % t2) | |||
logger.info('Second time: %s' % t2) | |||
# set | |||
release_datetime_backup = version.__release_datetime__ | |||
logger.info('Origin __release_datetime__: %s' | |||
@@ -142,6 +142,43 @@ class HubRevisionTest(unittest.TestCase): | |||
finally: | |||
version.__release_datetime__ = release_datetime_backup | |||
def test_snapshot_download_revision_user_set_revision(self): | |||
with mock.patch.dict(os.environ, self.modified_environ, clear=True): | |||
self.prepare_repo_data_and_tag() | |||
t1 = datetime.now().isoformat(sep=' ', timespec='seconds') | |||
logger.info('First time: %s' % t1) | |||
time.sleep(10) | |||
self.add_new_file_and_tag_to_repo() | |||
t2 = datetime.now().isoformat(sep=' ', timespec='seconds') | |||
logger.info('Secnod time: %s' % t2) | |||
# set | |||
release_datetime_backup = version.__release_datetime__ | |||
logger.info('Origin __release_datetime__: %s' | |||
% version.__release_datetime__) | |||
try: | |||
logger.info('Setting __release_datetime__ to: %s' % t1) | |||
version.__release_datetime__ = t1 | |||
with tempfile.TemporaryDirectory() as temp_cache_dir: | |||
snapshot_path = snapshot_download( | |||
self.model_id, | |||
revision=self.revision, | |||
cache_dir=temp_cache_dir) | |||
assert os.path.exists( | |||
os.path.join(snapshot_path, download_model_file_name)) | |||
assert not os.path.exists( | |||
os.path.join(snapshot_path, download_model_file_name2)) | |||
with tempfile.TemporaryDirectory() as temp_cache_dir: | |||
snapshot_path = snapshot_download( | |||
self.model_id, | |||
revision=self.revision2, | |||
cache_dir=temp_cache_dir) | |||
assert os.path.exists( | |||
os.path.join(snapshot_path, download_model_file_name)) | |||
assert os.path.exists( | |||
os.path.join(snapshot_path, download_model_file_name2)) | |||
finally: | |||
version.__release_datetime__ = release_datetime_backup | |||
def test_file_download_revision(self): | |||
with mock.patch.dict(os.environ, self.modified_environ, clear=True): | |||
self.prepare_repo_data_and_tag() | |||
@@ -175,7 +212,6 @@ class HubRevisionTest(unittest.TestCase): | |||
self.model_id, | |||
download_model_file_name, | |||
cache_dir=temp_cache_dir) | |||
print('Downloaded file path: %s' % file_path) | |||
assert os.path.exists(file_path) | |||
file_path = model_file_download( | |||
self.model_id, | |||
@@ -185,6 +221,50 @@ class HubRevisionTest(unittest.TestCase): | |||
finally: | |||
version.__release_datetime__ = release_datetime_backup | |||
def test_file_download_revision_user_set_revision(self): | |||
with mock.patch.dict(os.environ, self.modified_environ, clear=True): | |||
self.prepare_repo_data_and_tag() | |||
t1 = datetime.now().isoformat(sep=' ', timespec='seconds') | |||
logger.info('First time stamp: %s' % t1) | |||
time.sleep(10) | |||
self.add_new_file_and_tag_to_repo() | |||
t2 = datetime.now().isoformat(sep=' ', timespec='seconds') | |||
logger.info('Second time: %s' % t2) | |||
release_datetime_backup = version.__release_datetime__ | |||
logger.info('Origin __release_datetime__: %s' | |||
% version.__release_datetime__) | |||
try: | |||
version.__release_datetime__ = t1 | |||
logger.info('Setting __release_datetime__ to: %s' % t1) | |||
with tempfile.TemporaryDirectory() as temp_cache_dir: | |||
file_path = model_file_download( | |||
self.model_id, | |||
download_model_file_name, | |||
revision=self.revision, | |||
cache_dir=temp_cache_dir) | |||
assert os.path.exists(file_path) | |||
with self.assertRaises(NotExistError): | |||
model_file_download( | |||
self.model_id, | |||
download_model_file_name2, | |||
revision=self.revision, | |||
cache_dir=temp_cache_dir) | |||
with tempfile.TemporaryDirectory() as temp_cache_dir: | |||
file_path = model_file_download( | |||
self.model_id, | |||
download_model_file_name, | |||
revision=self.revision2, | |||
cache_dir=temp_cache_dir) | |||
assert os.path.exists(file_path) | |||
file_path = model_file_download( | |||
self.model_id, | |||
download_model_file_name2, | |||
revision=self.revision2, | |||
cache_dir=temp_cache_dir) | |||
assert os.path.exists(file_path) | |||
finally: | |||
version.__release_datetime__ = release_datetime_backup | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -8,7 +8,8 @@ import zipfile | |||
from modelscope.msdatasets import MsDataset | |||
from modelscope.msdatasets.utils.dataset_utils import list_dataset_objects | |||
from modelscope.utils import logger as logging | |||
from modelscope.utils.constant import DEFAULT_DATASET_REVISION, ModelFile | |||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, DownloadMode, | |||
ModelFile) | |||
from modelscope.utils.test_utils import test_level | |||
logger = logging.get_logger(__name__) | |||
@@ -104,7 +105,10 @@ class DatasetUploadTest(unittest.TestCase): | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_ds_download_dir(self): | |||
test_ds = MsDataset.load(self.dataset_name, self.namespace) | |||
test_ds = MsDataset.load( | |||
self.dataset_name, | |||
namespace=self.namespace, | |||
download_mode=DownloadMode.FORCE_REDOWNLOAD) | |||
assert test_ds.config_kwargs['split_config'].values() | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
@@ -21,9 +21,10 @@ class TestModelOutput(unittest.TestCase): | |||
self.assertEqual(outputs['logits'], torch.Tensor([1])) | |||
self.assertEqual(outputs[0], torch.Tensor([1])) | |||
self.assertEqual(outputs.logits, torch.Tensor([1])) | |||
outputs.loss = torch.Tensor([2]) | |||
logits, loss = outputs | |||
self.assertEqual(logits, torch.Tensor([1])) | |||
self.assertTrue(loss is None) | |||
self.assertTrue(loss is not None) | |||
if __name__ == '__main__': | |||
@@ -1,11 +1,10 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
import cv2 | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.cv.image_utils import draw_106face_keypoints | |||
from modelscope.utils.test_utils import test_level | |||
@@ -13,7 +12,7 @@ class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase): | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_face_2d_keypoints(self): | |||
img_path = 'data/test/images/keypoints_detect/test_img_face_2d_keypoints.png' | |||
img_path = 'data/test/images/face_detection.png' | |||
model_id = 'damo/cv_mobilenet_face-2d-keypoints_alignment' | |||
face_2d_keypoints_align = pipeline( | |||
@@ -21,15 +20,21 @@ class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase): | |||
output = face_2d_keypoints_align(img_path) | |||
output_keypoints = output[OutputKeys.KEYPOINTS] | |||
output_pose = output[OutputKeys.POSES] | |||
img = cv2.imread(img_path) | |||
img = face_2d_keypoints_align.show_result( | |||
img, output_keypoints, scale=2, save_path='face_keypoints.jpg') | |||
self.assertEqual(output_keypoints.shape[0], 106) | |||
self.assertEqual(output_keypoints.shape[1], 2) | |||
self.assertEqual(output_pose.shape[0], 3) | |||
output_poses = output[OutputKeys.POSES] | |||
output_boxes = output[OutputKeys.BOXES] | |||
draw_106face_keypoints( | |||
img_path, | |||
output_keypoints, | |||
output_boxes, | |||
scale=2, | |||
save_path='face_keypoints.jpg') | |||
for idx in range(len(output_keypoints)): | |||
self.assertEqual(output_keypoints[idx].shape[0], 106) | |||
self.assertEqual(output_keypoints[idx].shape[1], 2) | |||
self.assertEqual(output_poses[idx].shape[0], 3) | |||
self.assertEqual(output_boxes[idx].shape[0], 4) | |||
if __name__ == '__main__': | |||
@@ -17,12 +17,12 @@ class FaceEmotionTest(unittest.TestCase): | |||
result = pipeline(input) | |||
print(result) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
@unittest.skip('skip since the model is set to private for now') | |||
def test_run_modelhub(self): | |||
face_emotion = pipeline(Tasks.face_emotion, model=self.model) | |||
self.pipeline_inference(face_emotion, self.img) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
@unittest.skip('skip since the model is set to private for now') | |||
def test_run_modelhub_default_model(self): | |||
face_emotion = pipeline(Tasks.face_emotion) | |||
self.pipeline_inference(face_emotion, self.img) | |||
@@ -23,7 +23,7 @@ class FacialExpressionRecognitionTest(unittest.TestCase): | |||
cv2.imwrite('result.png', img) | |||
print(f'output written to {osp.abspath("result.png")}') | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
@unittest.skip('skip since the model is set to private for now') | |||
def test_run_modelhub(self): | |||
fer = pipeline( | |||
Tasks.facial_expression_recognition, model=self.model_id) | |||
@@ -24,7 +24,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): | |||
def test_run(self): | |||
pipeline_multi_modal_embedding = pipeline( | |||
Tasks.multi_modal_embedding, model=self.model_id) | |||
text_embedding = pipeline_multi_modal_embedding( | |||
text_embedding = pipeline_multi_modal_embedding.forward( | |||
self.test_input)[OutputKeys.TEXT_EMBEDDING] | |||
print('l1-norm: {}'.format( | |||
torch.norm(text_embedding, p=1, dim=-1).item())) | |||
@@ -36,7 +36,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): | |||
model = Model.from_pretrained(self.model_id) | |||
pipeline_multi_modal_embedding = pipeline( | |||
task=Tasks.multi_modal_embedding, model=model) | |||
text_embedding = pipeline_multi_modal_embedding( | |||
text_embedding = pipeline_multi_modal_embedding.forward( | |||
self.test_input)[OutputKeys.TEXT_EMBEDDING] | |||
print('l1-norm: {}'.format( | |||
torch.norm(text_embedding, p=1, dim=-1).item())) | |||
@@ -47,7 +47,7 @@ class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): | |||
def test_run_with_default_model(self): | |||
pipeline_multi_modal_embedding = pipeline( | |||
task=Tasks.multi_modal_embedding) | |||
text_embedding = pipeline_multi_modal_embedding( | |||
text_embedding = pipeline_multi_modal_embedding.forward( | |||
self.test_input)[OutputKeys.TEXT_EMBEDDING] | |||
print('l1-norm: {}'.format( | |||
torch.norm(text_embedding, p=1, dim=-1).item())) | |||
@@ -19,9 +19,11 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
self.task = Tasks.named_entity_recognition | |||
self.model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news' | |||
english_model_id = 'damo/nlp_raner_named-entity-recognition_english-large-ecom' | |||
tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news' | |||
lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news' | |||
sentence = '这与温岭市新河镇的一个神秘的传说有关。' | |||
sentence_en = 'pizza shovel' | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_tcrf_by_direct_model_download(self): | |||
@@ -89,6 +91,12 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
task=Tasks.named_entity_recognition, model=self.lcrf_model_id) | |||
print(pipeline_ins(input=self.sentence)) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_english_with_model_name(self): | |||
pipeline_ins = pipeline( | |||
task=Tasks.named_entity_recognition, model=self.english_model_id) | |||
print(pipeline_ins(input='pizza shovel')) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_with_default_model(self): | |||
pipeline_ins = pipeline(task=Tasks.named_entity_recognition) | |||
@@ -19,7 +19,7 @@ class UnifoldProteinStructureTest(unittest.TestCase, DemoCompatibilityCheck): | |||
self.protein_multimer = 'GAMGLPEEPSSPQESTLKALSLYEAHLSSYIMYLQTFLVKTKQKVNNKNYPEFTLFDTSKLKKDQTLKSIKT' + \ | |||
'NIAALKNHIDKIKPIAMQIYKKYSKNIP' | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_by_direct_model_download(self): | |||
model_dir = snapshot_download(self.model_id) | |||
mono_pipeline_ins = pipeline(task=self.task, model=model_dir) | |||
@@ -50,7 +50,8 @@ class EasyCVTrainerTestFace2DKeypoints(unittest.TestCase): | |||
trainer = build_trainer(trainer_name, kwargs) | |||
trainer.train() | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
@unittest.skip( | |||
'skip since face_2d_keypoints_dataset is set to private for now') | |||
def test_trainer_single_gpu(self): | |||
temp_file_dir = tempfile.TemporaryDirectory() | |||
tmp_dir = temp_file_dir.name | |||
@@ -0,0 +1,83 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import shutil | |||
import unittest | |||
import json | |||
from modelscope.metainfo import Metrics, Trainers | |||
from modelscope.msdatasets import MsDataset | |||
from modelscope.trainers import build_trainer | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.test_utils import test_level | |||
class TestClipTrainer(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.finetune_cfg = \ | |||
{'framework': 'pytorch', | |||
'task': 'multi-modal-embedding', | |||
'pipeline': {'type': 'multi-modal-embedding'}, | |||
'pretrained_model': {'model_name': 'damo/multi-modal_clip-vit-base-patch16_zh'}, | |||
'dataset': {'column_map': {'img': 'image', 'text': 'query'}}, | |||
'train': {'work_dir': './workspace/ckpts/clip', | |||
# 'launcher': 'pytorch', | |||
'max_epochs': 1, | |||
'use_fp16': True, | |||
'dataloader': {'batch_size_per_gpu': 8, | |||
'workers_per_gpu': 0, | |||
'shuffle': True, | |||
'drop_last': True}, | |||
'lr_scheduler': {'name': 'cosine', | |||
'warmup_proportion': 0.01}, | |||
'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False}, | |||
'optimizer': {'type': 'AdamW'}, | |||
'optimizer_hparams': {'lr': 5e-05, 'weight_decay': 0.01}, | |||
'optimizer_hook': {'type': 'TorchAMPOptimizerHook', | |||
'cumulative_iters': 1, | |||
'loss_keys': 'loss'}, | |||
'loss_cfg': {'aggregate': True}, | |||
'hooks': [{'type': 'BestCkptSaverHook', | |||
'metric_key': 'inbatch_t2i_recall_at_1', | |||
'interval': 100}, | |||
{'type': 'TextLoggerHook', 'interval': 1}, | |||
{'type': 'IterTimerHook'}, | |||
{'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}, | |||
{'type': 'ClipClampLogitScaleHook'}]}, | |||
'evaluation': {'dataloader': {'batch_size_per_gpu': 8, | |||
'workers_per_gpu': 0, | |||
'shuffle': True, | |||
'drop_last': True}, | |||
'metrics': [{'type': 'inbatch_recall'}]}, | |||
'preprocessor': []} | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_trainer_std(self): | |||
WORKSPACE = './workspace/ckpts/clip' | |||
os.makedirs(WORKSPACE, exist_ok=True) | |||
config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION) | |||
with open(config_file, 'w') as writer: | |||
json.dump(self.finetune_cfg, writer) | |||
pretrained_model = 'damo/multi-modal_clip-vit-base-patch16_zh' | |||
args = dict( | |||
model=pretrained_model, | |||
work_dir=WORKSPACE, | |||
train_dataset=MsDataset.load( | |||
'muge', namespace='modelscope', split='train[:200]'), | |||
eval_dataset=MsDataset.load( | |||
'muge', namespace='modelscope', split='validation[:100]'), | |||
metrics=[Metrics.inbatch_recall], | |||
cfg_file=config_file) | |||
trainer = build_trainer( | |||
name=Trainers.clip_multi_modal_embedding, default_args=args) | |||
trainer.train() | |||
self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, | |||
os.listdir(os.path.join(WORKSPACE, 'output'))) | |||
shutil.rmtree(WORKSPACE) | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -38,7 +38,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
shutil.rmtree(self.tmp_dir) | |||
super().tearDown() | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
@unittest.skip | |||
def test_trainer_cfg_class(self): | |||
dataset = MsDataset.load('clue', subset_name='tnews') | |||
train_dataset = dataset['train'] | |||
@@ -87,7 +87,7 @@ class TestFinetuneTokenClassification(unittest.TestCase): | |||
cfg['dataset'] = { | |||
'train': { | |||
'labels': label_enumerate_values, | |||
'first_sequence': 'first_sequence', | |||
'first_sequence': 'tokens', | |||
'label': 'labels', | |||
} | |||
} | |||
@@ -85,7 +85,7 @@ class TestOfaTrainer(unittest.TestCase): | |||
'ocr_fudanvi_zh', | |||
subset_name='scene', | |||
namespace='modelscope', | |||
split='train[:200]', | |||
split='train[800:900]', | |||
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS), | |||
eval_dataset=MsDataset.load( | |||
'ocr_fudanvi_zh', | |||
@@ -72,7 +72,7 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) | |||
pipeline_sentence_similarity(output_dir) | |||
@unittest.skipUnless(test_level() >= 3, 'skip test in current test level') | |||
@unittest.skip | |||
def test_trainer_with_backbone_head(self): | |||
model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | |||
kwargs = dict( | |||