| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:3b230497f6ca10be42aed92b86db435d74fd7306746a059b4ad1e0d6b0652806 | |||||
| size 35694 | |||||
| @@ -36,6 +36,7 @@ class Models(object): | |||||
| swinL_semantic_segmentation = 'swinL-semantic-segmentation' | swinL_semantic_segmentation = 'swinL-semantic-segmentation' | ||||
| vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | ||||
| text_driven_segmentation = 'text-driven-segmentation' | text_driven_segmentation = 'text-driven-segmentation' | ||||
| newcrfs_depth_estimation = 'newcrfs-depth-estimation' | |||||
| resnet50_bert = 'resnet50-bert' | resnet50_bert = 'resnet50-bert' | ||||
| referring_video_object_segmentation = 'swinT-referring-video-object-segmentation' | referring_video_object_segmentation = 'swinT-referring-video-object-segmentation' | ||||
| fer = 'fer' | fer = 'fer' | ||||
| @@ -210,6 +211,7 @@ class Pipelines(object): | |||||
| video_summarization = 'googlenet_pgl_video_summarization' | video_summarization = 'googlenet_pgl_video_summarization' | ||||
| language_guided_video_summarization = 'clip-it-video-summarization' | language_guided_video_summarization = 'clip-it-video-summarization' | ||||
| image_semantic_segmentation = 'image-semantic-segmentation' | image_semantic_segmentation = 'image-semantic-segmentation' | ||||
| image_depth_estimation = 'image-depth-estimation' | |||||
| image_reid_person = 'passvitb-image-reid-person' | image_reid_person = 'passvitb-image-reid-person' | ||||
| image_inpainting = 'fft-inpainting' | image_inpainting = 'fft-inpainting' | ||||
| text_driven_segmentation = 'text-driven-segmentation' | text_driven_segmentation = 'text-driven-segmentation' | ||||
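The two new registry keys above wire the NeWCRFs model into the standard pipeline factory. A minimal usage sketch follows; the task constant and the model id are placeholders for illustration, not names confirmed by this diff.

```python
# Hedged usage sketch: invoking the newly registered depth-estimation pipeline.
# Tasks.image_depth_estimation and the model id below are assumptions.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

depth_estimation = pipeline(
    Tasks.image_depth_estimation,                      # assumed task constant
    model='damo/cv_newcrfs_image-depth-estimation')    # hypothetical model id
result = depth_estimation('test.jpg')                  # dict containing the predicted depth map
print(result.keys())
```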
| @@ -5,10 +5,10 @@ from abc import ABC, abstractmethod | |||||
| from typing import Any, Callable, Dict, List, Optional, Union | from typing import Any, Callable, Dict, List, Optional, Union | ||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.models.builder import MODELS, build_model | |||||
| from modelscope.models.builder import build_model | |||||
| from modelscope.utils.checkpoint import save_checkpoint, save_pretrained | from modelscope.utils.checkpoint import save_checkpoint, save_pretrained | ||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile, Tasks | |||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile | |||||
| from modelscope.utils.device import verify_device | from modelscope.utils.device import verify_device | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| @@ -94,6 +94,10 @@ class Model(ABC): | |||||
| if prefetched is not None: | if prefetched is not None: | ||||
| kwargs.pop('model_prefetched') | kwargs.pop('model_prefetched') | ||||
| invoked_by = kwargs.get(Invoke.KEY) | |||||
| if invoked_by is not None: | |||||
| kwargs.pop(Invoke.KEY) | |||||
| if osp.exists(model_name_or_path): | if osp.exists(model_name_or_path): | ||||
| local_model_dir = model_name_or_path | local_model_dir = model_name_or_path | ||||
| else: | else: | ||||
| @@ -101,7 +105,13 @@ class Model(ABC): | |||||
| raise RuntimeError( | raise RuntimeError( | ||||
| 'Expecting model is pre-fetched locally, but is not found.' | 'Expecting model is pre-fetched locally, but is not found.' | ||||
| ) | ) | ||||
| local_model_dir = snapshot_download(model_name_or_path, revision) | |||||
| if invoked_by is not None: | |||||
| invoked_by = '%s/%s' % (Invoke.KEY, invoked_by) | |||||
| else: | |||||
| invoked_by = '%s/%s' % (Invoke.KEY, Invoke.PRETRAINED) | |||||
| local_model_dir = snapshot_download( | |||||
| model_name_or_path, revision, user_agent=invoked_by) | |||||
| logger.info(f'initialize model from {local_model_dir}') | logger.info(f'initialize model from {local_model_dir}') | ||||
| if cfg_dict is not None: | if cfg_dict is not None: | ||||
| cfg = cfg_dict | cfg = cfg_dict | ||||
| @@ -133,6 +143,7 @@ class Model(ABC): | |||||
| model.cfg = cfg | model.cfg = cfg | ||||
| model.name = model_name_or_path | model.name = model_name_or_path | ||||
| model.model_dir = local_model_dir | |||||
| return model | return model | ||||
| def save_pretrained(self, | def save_pretrained(self, | ||||
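The hunks above thread an optional invoke tag through the classmethod (shown in the `class Model(ABC)` context, presumably `from_pretrained`) so that `snapshot_download` reports who triggered the download via `user_agent`. A rough sketch of the calling side, assuming an `Invoke.PIPELINE` constant exists alongside `Invoke.PRETRAINED`:

```python
# Hedged sketch of how the Invoke tagging is consumed; constants other than
# Invoke.KEY / Invoke.PRETRAINED are assumptions, and the model id is a placeholder.
from modelscope.models import Model
from modelscope.utils.constant import Invoke

# Without an explicit tag, the download is reported as '<Invoke.KEY>/<Invoke.PRETRAINED>'.
model = Model.from_pretrained('damo/some-model-id')

# A caller such as the pipeline builder could label the download explicitly:
model = Model.from_pretrained('damo/some-model-id', **{Invoke.KEY: Invoke.PIPELINE})
```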
| @@ -224,8 +224,8 @@ class BodyKeypointsDetection3D(TorchModel): | |||||
| lst_pose2d_cannoical.append(pose2d_canonical[:, | lst_pose2d_cannoical.append(pose2d_canonical[:, | ||||
| i - pad:i + pad + 1]) | i - pad:i + pad + 1]) | ||||
| input_pose2d_rr = torch.concat(lst_pose2d_cannoical, axis=0) | |||||
| input_pose2d_cannoical = torch.concat(lst_pose2d_cannoical, axis=0) | |||||
| input_pose2d_rr = torch.cat(lst_pose2d_cannoical, axis=0) | |||||
| input_pose2d_cannoical = torch.cat(lst_pose2d_cannoical, axis=0) | |||||
| if self.cfg.model.MODEL.USE_CANONICAL_COORDS: | if self.cfg.model.MODEL.USE_CANONICAL_COORDS: | ||||
| input_pose2d_abs = input_pose2d_cannoical.clone() | input_pose2d_abs = input_pose2d_cannoical.clone() | ||||
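The substitution above swaps `torch.concat` for `torch.cat`: `concat` is only an alias introduced in newer PyTorch releases, while `cat` is available across versions. A one-line reference sketch:

```python
# torch.cat concatenates tensors along an existing dimension; torch.concat is
# merely a newer alias, hence the change in the hunk above.
import torch

a, b = torch.zeros(2, 17, 3), torch.ones(2, 17, 3)
x = torch.cat([a, b], dim=0)    # -> shape (4, 17, 3)
```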
| @@ -0,0 +1 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| @@ -0,0 +1 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| @@ -0,0 +1,215 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from .newcrf_layers import NewCRF | |||||
| from .swin_transformer import SwinTransformer | |||||
| from .uper_crf_head import PSP | |||||
| class NewCRFDepth(nn.Module): | |||||
| """ | |||||
| Depth network based on the neural window fully-connected CRFs (NeWCRFs) architecture. | |||||
| """ | |||||
| def __init__(self, | |||||
| version=None, | |||||
| inv_depth=False, | |||||
| pretrained=None, | |||||
| frozen_stages=-1, | |||||
| min_depth=0.1, | |||||
| max_depth=100.0, | |||||
| **kwargs): | |||||
| super().__init__() | |||||
| self.inv_depth = inv_depth | |||||
| self.with_auxiliary_head = False | |||||
| self.with_neck = False | |||||
| norm_cfg = dict(type='BN', requires_grad=True) | |||||
| # norm_cfg = dict(type='GN', requires_grad=True, num_groups=8) | |||||
| window_size = int(version[-2:]) | |||||
| if version[:-2] == 'base': | |||||
| embed_dim = 128 | |||||
| depths = [2, 2, 18, 2] | |||||
| num_heads = [4, 8, 16, 32] | |||||
| in_channels = [128, 256, 512, 1024] | |||||
| elif version[:-2] == 'large': | |||||
| embed_dim = 192 | |||||
| depths = [2, 2, 18, 2] | |||||
| num_heads = [6, 12, 24, 48] | |||||
| in_channels = [192, 384, 768, 1536] | |||||
| elif version[:-2] == 'tiny': | |||||
| embed_dim = 96 | |||||
| depths = [2, 2, 6, 2] | |||||
| num_heads = [3, 6, 12, 24] | |||||
| in_channels = [96, 192, 384, 768] | |||||
| backbone_cfg = dict( | |||||
| embed_dim=embed_dim, | |||||
| depths=depths, | |||||
| num_heads=num_heads, | |||||
| window_size=window_size, | |||||
| ape=False, | |||||
| drop_path_rate=0.3, | |||||
| patch_norm=True, | |||||
| use_checkpoint=False, | |||||
| frozen_stages=frozen_stages) | |||||
| embed_dim = 512 | |||||
| decoder_cfg = dict( | |||||
| in_channels=in_channels, | |||||
| in_index=[0, 1, 2, 3], | |||||
| pool_scales=(1, 2, 3, 6), | |||||
| channels=embed_dim, | |||||
| dropout_ratio=0.0, | |||||
| num_classes=32, | |||||
| norm_cfg=norm_cfg, | |||||
| align_corners=False) | |||||
| self.backbone = SwinTransformer(**backbone_cfg) | |||||
| # v_dim = decoder_cfg['num_classes'] * 4 | |||||
| win = 7 | |||||
| crf_dims = [128, 256, 512, 1024] | |||||
| v_dims = [64, 128, 256, embed_dim] | |||||
| self.crf3 = NewCRF( | |||||
| input_dim=in_channels[3], | |||||
| embed_dim=crf_dims[3], | |||||
| window_size=win, | |||||
| v_dim=v_dims[3], | |||||
| num_heads=32) | |||||
| self.crf2 = NewCRF( | |||||
| input_dim=in_channels[2], | |||||
| embed_dim=crf_dims[2], | |||||
| window_size=win, | |||||
| v_dim=v_dims[2], | |||||
| num_heads=16) | |||||
| self.crf1 = NewCRF( | |||||
| input_dim=in_channels[1], | |||||
| embed_dim=crf_dims[1], | |||||
| window_size=win, | |||||
| v_dim=v_dims[1], | |||||
| num_heads=8) | |||||
| self.crf0 = NewCRF( | |||||
| input_dim=in_channels[0], | |||||
| embed_dim=crf_dims[0], | |||||
| window_size=win, | |||||
| v_dim=v_dims[0], | |||||
| num_heads=4) | |||||
| self.decoder = PSP(**decoder_cfg) | |||||
| self.disp_head1 = DispHead(input_dim=crf_dims[0]) | |||||
| self.up_mode = 'bilinear' | |||||
| if self.up_mode == 'mask': | |||||
| self.mask_head = nn.Sequential( | |||||
| nn.Conv2d(crf_dims[0], 64, 3, padding=1), | |||||
| nn.ReLU(inplace=True), nn.Conv2d(64, 16 * 9, 1, padding=0)) | |||||
| self.min_depth = min_depth | |||||
| self.max_depth = max_depth | |||||
| self.init_weights(pretrained=pretrained) | |||||
| def init_weights(self, pretrained=None): | |||||
| """Initialize the weights in backbone and heads. | |||||
| Args: | |||||
| pretrained (str, optional): Path to pre-trained weights. | |||||
| Defaults to None. | |||||
| """ | |||||
| # print(f'== Load encoder backbone from: {pretrained}') | |||||
| self.backbone.init_weights(pretrained=pretrained) | |||||
| self.decoder.init_weights() | |||||
| if self.with_auxiliary_head: | |||||
| if isinstance(self.auxiliary_head, nn.ModuleList): | |||||
| for aux_head in self.auxiliary_head: | |||||
| aux_head.init_weights() | |||||
| else: | |||||
| self.auxiliary_head.init_weights() | |||||
| def upsample_mask(self, disp, mask): | |||||
| """ Upsample disp [H/4, W/4, 1] -> [H, W, 1] using convex combination """ | |||||
| N, _, H, W = disp.shape | |||||
| mask = mask.view(N, 1, 9, 4, 4, H, W) | |||||
| mask = torch.softmax(mask, dim=2) | |||||
| up_disp = F.unfold(disp, kernel_size=3, padding=1) | |||||
| up_disp = up_disp.view(N, 1, 9, 1, 1, H, W) | |||||
| up_disp = torch.sum(mask * up_disp, dim=2) | |||||
| up_disp = up_disp.permute(0, 1, 4, 2, 5, 3) | |||||
| return up_disp.reshape(N, 1, 4 * H, 4 * W) | |||||
| def forward(self, imgs): | |||||
| feats = self.backbone(imgs) | |||||
| if self.with_neck: | |||||
| feats = self.neck(feats) | |||||
| ppm_out = self.decoder(feats) | |||||
| e3 = self.crf3(feats[3], ppm_out) | |||||
| e3 = nn.PixelShuffle(2)(e3) | |||||
| e2 = self.crf2(feats[2], e3) | |||||
| e2 = nn.PixelShuffle(2)(e2) | |||||
| e1 = self.crf1(feats[1], e2) | |||||
| e1 = nn.PixelShuffle(2)(e1) | |||||
| e0 = self.crf0(feats[0], e1) | |||||
| if self.up_mode == 'mask': | |||||
| mask = self.mask_head(e0) | |||||
| d1 = self.disp_head1(e0, 1) | |||||
| d1 = self.upsample_mask(d1, mask) | |||||
| else: | |||||
| d1 = self.disp_head1(e0, 4) | |||||
| depth = d1 * self.max_depth | |||||
| return depth | |||||
| class DispHead(nn.Module): | |||||
| def __init__(self, input_dim=100): | |||||
| super(DispHead, self).__init__() | |||||
| # self.norm1 = nn.BatchNorm2d(input_dim) | |||||
| self.conv1 = nn.Conv2d(input_dim, 1, 3, padding=1) | |||||
| # self.relu = nn.ReLU(inplace=True) | |||||
| self.sigmoid = nn.Sigmoid() | |||||
| def forward(self, x, scale): | |||||
| # x = self.relu(self.norm1(x)) | |||||
| x = self.sigmoid(self.conv1(x)) | |||||
| if scale > 1: | |||||
| x = upsample(x, scale_factor=scale) | |||||
| return x | |||||
| class DispUnpack(nn.Module): | |||||
| def __init__(self, input_dim=100, hidden_dim=128): | |||||
| super(DispUnpack, self).__init__() | |||||
| self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) | |||||
| self.conv2 = nn.Conv2d(hidden_dim, 16, 3, padding=1) | |||||
| self.relu = nn.ReLU(inplace=True) | |||||
| self.sigmoid = nn.Sigmoid() | |||||
| self.pixel_shuffle = nn.PixelShuffle(4) | |||||
| def forward(self, x, output_size): | |||||
| x = self.relu(self.conv1(x)) | |||||
| x = self.sigmoid(self.conv2(x)) # [b, 16, h/4, w/4] | |||||
| # x = torch.reshape(x, [x.shape[0], 1, x.shape[2]*4, x.shape[3]*4]) | |||||
| x = self.pixel_shuffle(x) | |||||
| return x | |||||
| def upsample(x, scale_factor=2, mode='bilinear', align_corners=False): | |||||
| """Upsample input tensor by a factor of 2 | |||||
| """ | |||||
| return F.interpolate( | |||||
| x, scale_factor=scale_factor, mode=mode, align_corners=align_corners) | |||||
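To make the decoder wiring above concrete, here is a quick shape-check sketch under stated assumptions: the companion modules (`SwinTransformer`, `NewCRF`, `PSP`) import cleanly, weights are randomly initialized, and `'large07'` selects the Swin-Large variant with a 7x7 window.

```python
# Hedged shape-check sketch for NewCRFDepth (random weights, no pretrained checkpoint).
import torch

model = NewCRFDepth(version='large07', inv_depth=False, max_depth=10.0)
model.eval()
with torch.no_grad():
    imgs = torch.randn(1, 3, 352, 704)   # NCHW, spatial dims divisible by 32
    depth = model(imgs)                  # crf0 output at 1/4 scale, then upsampled x4
print(depth.shape)                       # expected: torch.Size([1, 1, 352, 704])
```

The cascade runs coarse-to-fine: each CRF stage consumes the backbone feature at its scale plus the pixel-shuffled output of the coarser stage, and the final disparity head maps the 1/4-scale features back to full resolution.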
| @@ -0,0 +1,504 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| import torch.utils.checkpoint as checkpoint | |||||
| from timm.models.layers import DropPath, to_2tuple, trunc_normal_ | |||||
| class Mlp(nn.Module): | |||||
| """ Multilayer perceptron.""" | |||||
| def __init__(self, | |||||
| in_features, | |||||
| hidden_features=None, | |||||
| out_features=None, | |||||
| act_layer=nn.GELU, | |||||
| drop=0.): | |||||
| super().__init__() | |||||
| out_features = out_features or in_features | |||||
| hidden_features = hidden_features or in_features | |||||
| self.fc1 = nn.Linear(in_features, hidden_features) | |||||
| self.act = act_layer() | |||||
| self.fc2 = nn.Linear(hidden_features, out_features) | |||||
| self.drop = nn.Dropout(drop) | |||||
| def forward(self, x): | |||||
| x = self.fc1(x) | |||||
| x = self.act(x) | |||||
| x = self.drop(x) | |||||
| x = self.fc2(x) | |||||
| x = self.drop(x) | |||||
| return x | |||||
| def window_partition(x, window_size): | |||||
| """ | |||||
| Args: | |||||
| x: (B, H, W, C) | |||||
| window_size (int): window size | |||||
| Returns: | |||||
| windows: (num_windows*B, window_size, window_size, C) | |||||
| """ | |||||
| B, H, W, C = x.shape | |||||
| x = x.view(B, H // window_size, window_size, W // window_size, window_size, | |||||
| C) | |||||
| windows = x.permute(0, 1, 3, 2, 4, | |||||
| 5).contiguous().view(-1, window_size, window_size, C) | |||||
| return windows | |||||
| def window_reverse(windows, window_size, H, W): | |||||
| """ | |||||
| Args: | |||||
| windows: (num_windows*B, window_size, window_size, C) | |||||
| window_size (int): Window size | |||||
| H (int): Height of image | |||||
| W (int): Width of image | |||||
| Returns: | |||||
| x: (B, H, W, C) | |||||
| """ | |||||
| B = int(windows.shape[0] / (H * W / window_size / window_size)) | |||||
| x = windows.view(B, H // window_size, W // window_size, window_size, | |||||
| window_size, -1) | |||||
| x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) | |||||
| return x | |||||
| class WindowAttention(nn.Module): | |||||
| """ Window based multi-head self attention (W-MSA) module with relative position bias. | |||||
| It supports both shifted and non-shifted windows. | |||||
| Args: | |||||
| dim (int): Number of input channels. | |||||
| window_size (tuple[int]): The height and width of the window. | |||||
| num_heads (int): Number of attention heads. | |||||
| qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||||
| qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set | |||||
| attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 | |||||
| proj_drop (float, optional): Dropout ratio of output. Default: 0.0 | |||||
| """ | |||||
| def __init__(self, | |||||
| dim, | |||||
| window_size, | |||||
| num_heads, | |||||
| v_dim, | |||||
| qkv_bias=True, | |||||
| qk_scale=None, | |||||
| attn_drop=0., | |||||
| proj_drop=0.): | |||||
| super().__init__() | |||||
| self.dim = dim | |||||
| self.window_size = window_size # Wh, Ww | |||||
| self.num_heads = num_heads | |||||
| head_dim = dim // num_heads | |||||
| self.scale = qk_scale or head_dim**-0.5 | |||||
| # define a parameter table of relative position bias | |||||
| self.relative_position_bias_table = nn.Parameter( | |||||
| torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), | |||||
| num_heads)) # 2*Wh-1 * 2*Ww-1, nH | |||||
| # get pair-wise relative position index for each token inside the window | |||||
| coords_h = torch.arange(self.window_size[0]) | |||||
| coords_w = torch.arange(self.window_size[1]) | |||||
| coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww | |||||
| coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww | |||||
| relative_coords = coords_flatten[:, :, | |||||
| None] - coords_flatten[:, | |||||
| None, :] # 2, Wh*Ww, Wh*Ww | |||||
| relative_coords = relative_coords.permute( | |||||
| 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 | |||||
| relative_coords[:, :, | |||||
| 0] += self.window_size[0] - 1 # shift to start from 0 | |||||
| relative_coords[:, :, 1] += self.window_size[1] - 1 | |||||
| relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 | |||||
| relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww | |||||
| self.register_buffer('relative_position_index', | |||||
| relative_position_index) | |||||
| self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias) | |||||
| self.attn_drop = nn.Dropout(attn_drop) | |||||
| self.proj = nn.Linear(v_dim, v_dim) | |||||
| self.proj_drop = nn.Dropout(proj_drop) | |||||
| trunc_normal_(self.relative_position_bias_table, std=.02) | |||||
| self.softmax = nn.Softmax(dim=-1) | |||||
| def forward(self, x, v, mask=None): | |||||
| """ Forward function. | |||||
| Args: | |||||
| x: input features with shape of (num_windows*B, N, C) | |||||
| mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None | |||||
| """ | |||||
| B_, N, C = x.shape | |||||
| qk = self.qk(x).reshape(B_, N, 2, self.num_heads, | |||||
| C // self.num_heads).permute(2, 0, 3, 1, 4) | |||||
| q, k = qk[0], qk[ | |||||
| 1] # make torchscript happy (cannot use tensor as tuple) | |||||
| q = q * self.scale | |||||
| attn = (q @ k.transpose(-2, -1)) | |||||
| relative_position_bias = self.relative_position_bias_table[ | |||||
| self.relative_position_index.view(-1)].view( | |||||
| self.window_size[0] * self.window_size[1], | |||||
| self.window_size[0] * self.window_size[1], | |||||
| -1) # Wh*Ww,Wh*Ww,nH | |||||
| relative_position_bias = relative_position_bias.permute( | |||||
| 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww | |||||
| attn = attn + relative_position_bias.unsqueeze(0) | |||||
| if mask is not None: | |||||
| nW = mask.shape[0] | |||||
| attn = attn.view(B_ // nW, nW, self.num_heads, N, | |||||
| N) + mask.unsqueeze(1).unsqueeze(0) | |||||
| attn = attn.view(-1, self.num_heads, N, N) | |||||
| attn = self.softmax(attn) | |||||
| else: | |||||
| attn = self.softmax(attn) | |||||
| attn = self.attn_drop(attn) | |||||
| # assert self.dim % v.shape[-1] == 0, "self.dim % v.shape[-1] != 0" | |||||
| # repeat_num = self.dim // v.shape[-1] | |||||
| # v = v.view(B_, N, self.num_heads // repeat_num, -1).transpose(1, 2).repeat(1, repeat_num, 1, 1) | |||||
| assert self.dim == v.shape[-1], 'self.dim != v.shape[-1]' | |||||
| v = v.view(B_, N, self.num_heads, -1).transpose(1, 2) | |||||
| x = (attn @ v).transpose(1, 2).reshape(B_, N, C) | |||||
| x = self.proj(x) | |||||
| x = self.proj_drop(x) | |||||
| return x | |||||
| class CRFBlock(nn.Module): | |||||
| """ CRF Block. | |||||
| Args: | |||||
| dim (int): Number of input channels. | |||||
| num_heads (int): Number of attention heads. | |||||
| window_size (int): Window size. | |||||
| shift_size (int): Shift size for SW-MSA. | |||||
| mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. | |||||
| qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||||
| qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. | |||||
| drop (float, optional): Dropout rate. Default: 0.0 | |||||
| attn_drop (float, optional): Attention dropout rate. Default: 0.0 | |||||
| drop_path (float, optional): Stochastic depth rate. Default: 0.0 | |||||
| act_layer (nn.Module, optional): Activation layer. Default: nn.GELU | |||||
| norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||||
| """ | |||||
| def __init__(self, | |||||
| dim, | |||||
| num_heads, | |||||
| v_dim, | |||||
| window_size=7, | |||||
| shift_size=0, | |||||
| mlp_ratio=4., | |||||
| qkv_bias=True, | |||||
| qk_scale=None, | |||||
| drop=0., | |||||
| attn_drop=0., | |||||
| drop_path=0., | |||||
| act_layer=nn.GELU, | |||||
| norm_layer=nn.LayerNorm): | |||||
| super().__init__() | |||||
| self.dim = dim | |||||
| self.num_heads = num_heads | |||||
| self.v_dim = v_dim | |||||
| self.window_size = window_size | |||||
| self.shift_size = shift_size | |||||
| self.mlp_ratio = mlp_ratio | |||||
| assert 0 <= self.shift_size < self.window_size, 'shift_size must be in [0, window_size)' | |||||
| self.norm1 = norm_layer(dim) | |||||
| self.attn = WindowAttention( | |||||
| dim, | |||||
| window_size=to_2tuple(self.window_size), | |||||
| num_heads=num_heads, | |||||
| v_dim=v_dim, | |||||
| qkv_bias=qkv_bias, | |||||
| qk_scale=qk_scale, | |||||
| attn_drop=attn_drop, | |||||
| proj_drop=drop) | |||||
| self.drop_path = DropPath( | |||||
| drop_path) if drop_path > 0. else nn.Identity() | |||||
| self.norm2 = norm_layer(v_dim) | |||||
| mlp_hidden_dim = int(v_dim * mlp_ratio) | |||||
| self.mlp = Mlp( | |||||
| in_features=v_dim, | |||||
| hidden_features=mlp_hidden_dim, | |||||
| act_layer=act_layer, | |||||
| drop=drop) | |||||
| self.H = None | |||||
| self.W = None | |||||
| def forward(self, x, v, mask_matrix): | |||||
| """ Forward function. | |||||
| Args: | |||||
| x: Input feature, tensor size (B, H*W, C). | |||||
| H, W: Spatial resolution of the input feature. | |||||
| mask_matrix: Attention mask for cyclic shift. | |||||
| """ | |||||
| B, L, C = x.shape | |||||
| H, W = self.H, self.W | |||||
| assert L == H * W, 'input feature has wrong size' | |||||
| shortcut = x | |||||
| x = self.norm1(x) | |||||
| x = x.view(B, H, W, C) | |||||
| # pad feature maps to multiples of window size | |||||
| pad_l = pad_t = 0 | |||||
| pad_r = (self.window_size - W % self.window_size) % self.window_size | |||||
| pad_b = (self.window_size - H % self.window_size) % self.window_size | |||||
| x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) | |||||
| v = F.pad(v, (0, 0, pad_l, pad_r, pad_t, pad_b)) | |||||
| _, Hp, Wp, _ = x.shape | |||||
| # cyclic shift | |||||
| if self.shift_size > 0: | |||||
| shifted_x = torch.roll( | |||||
| x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) | |||||
| shifted_v = torch.roll( | |||||
| v, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) | |||||
| attn_mask = mask_matrix | |||||
| else: | |||||
| shifted_x = x | |||||
| shifted_v = v | |||||
| attn_mask = None | |||||
| # partition windows | |||||
| x_windows = window_partition( | |||||
| shifted_x, self.window_size) # nW*B, window_size, window_size, C | |||||
| x_windows = x_windows.view(-1, self.window_size * self.window_size, | |||||
| C) # nW*B, window_size*window_size, C | |||||
| v_windows = window_partition( | |||||
| shifted_v, self.window_size) # nW*B, window_size, window_size, C | |||||
| v_windows = v_windows.view( | |||||
| -1, self.window_size * self.window_size, | |||||
| v_windows.shape[-1]) # nW*B, window_size*window_size, C | |||||
| # W-MSA/SW-MSA | |||||
| attn_windows = self.attn( | |||||
| x_windows, v_windows, | |||||
| mask=attn_mask) # nW*B, window_size*window_size, C | |||||
| # merge windows | |||||
| attn_windows = attn_windows.view(-1, self.window_size, | |||||
| self.window_size, self.v_dim) | |||||
| shifted_x = window_reverse(attn_windows, self.window_size, Hp, | |||||
| Wp) # B H' W' C | |||||
| # reverse cyclic shift | |||||
| if self.shift_size > 0: | |||||
| x = torch.roll( | |||||
| shifted_x, | |||||
| shifts=(self.shift_size, self.shift_size), | |||||
| dims=(1, 2)) | |||||
| else: | |||||
| x = shifted_x | |||||
| if pad_r > 0 or pad_b > 0: | |||||
| x = x[:, :H, :W, :].contiguous() | |||||
| x = x.view(B, H * W, self.v_dim) | |||||
| # FFN | |||||
| x = shortcut + self.drop_path(x) | |||||
| x = x + self.drop_path(self.mlp(self.norm2(x))) | |||||
| return x | |||||
| class BasicCRFLayer(nn.Module): | |||||
| """ A basic NeWCRFs layer for one stage. | |||||
| Args: | |||||
| dim (int): Number of feature channels | |||||
| depth (int): Depths of this stage. | |||||
| num_heads (int): Number of attention head. | |||||
| window_size (int): Local window size. Default: 7. | |||||
| mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. | |||||
| qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||||
| qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. | |||||
| drop (float, optional): Dropout rate. Default: 0.0 | |||||
| attn_drop (float, optional): Attention dropout rate. Default: 0.0 | |||||
| drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 | |||||
| norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||||
| downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None | |||||
| use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. | |||||
| """ | |||||
| def __init__(self, | |||||
| dim, | |||||
| depth, | |||||
| num_heads, | |||||
| v_dim, | |||||
| window_size=7, | |||||
| mlp_ratio=4., | |||||
| qkv_bias=True, | |||||
| qk_scale=None, | |||||
| drop=0., | |||||
| attn_drop=0., | |||||
| drop_path=0., | |||||
| norm_layer=nn.LayerNorm, | |||||
| downsample=None, | |||||
| use_checkpoint=False): | |||||
| super().__init__() | |||||
| self.window_size = window_size | |||||
| self.shift_size = window_size // 2 | |||||
| self.depth = depth | |||||
| self.use_checkpoint = use_checkpoint | |||||
| # build blocks | |||||
| self.blocks = nn.ModuleList([ | |||||
| CRFBlock( | |||||
| dim=dim, | |||||
| num_heads=num_heads, | |||||
| v_dim=v_dim, | |||||
| window_size=window_size, | |||||
| shift_size=0 if (i % 2 == 0) else window_size // 2, | |||||
| mlp_ratio=mlp_ratio, | |||||
| qkv_bias=qkv_bias, | |||||
| qk_scale=qk_scale, | |||||
| drop=drop, | |||||
| attn_drop=attn_drop, | |||||
| drop_path=drop_path[i] | |||||
| if isinstance(drop_path, list) else drop_path, | |||||
| norm_layer=norm_layer) for i in range(depth) | |||||
| ]) | |||||
| # patch merging layer | |||||
| if downsample is not None: | |||||
| self.downsample = downsample(dim=dim, norm_layer=norm_layer) | |||||
| else: | |||||
| self.downsample = None | |||||
| def forward(self, x, v, H, W): | |||||
| """ Forward function. | |||||
| Args: | |||||
| x: Input feature, tensor size (B, H*W, C). | |||||
| H, W: Spatial resolution of the input feature. | |||||
| """ | |||||
| # calculate attention mask for SW-MSA | |||||
| Hp = int(np.ceil(H / self.window_size)) * self.window_size | |||||
| Wp = int(np.ceil(W / self.window_size)) * self.window_size | |||||
| img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 | |||||
| h_slices = (slice(0, -self.window_size), | |||||
| slice(-self.window_size, | |||||
| -self.shift_size), slice(-self.shift_size, None)) | |||||
| w_slices = (slice(0, -self.window_size), | |||||
| slice(-self.window_size, | |||||
| -self.shift_size), slice(-self.shift_size, None)) | |||||
| cnt = 0 | |||||
| for h in h_slices: | |||||
| for w in w_slices: | |||||
| img_mask[:, h, w, :] = cnt | |||||
| cnt += 1 | |||||
| mask_windows = window_partition( | |||||
| img_mask, self.window_size) # nW, window_size, window_size, 1 | |||||
| mask_windows = mask_windows.view(-1, | |||||
| self.window_size * self.window_size) | |||||
| attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) | |||||
| attn_mask = attn_mask.masked_fill(attn_mask != 0, | |||||
| float(-100.0)).masked_fill( | |||||
| attn_mask == 0, float(0.0)) | |||||
| for blk in self.blocks: | |||||
| blk.H, blk.W = H, W | |||||
| if self.use_checkpoint: | |||||
| x = checkpoint.checkpoint(blk, x, v, attn_mask) | |||||
| else: | |||||
| x = blk(x, v, attn_mask) | |||||
| if self.downsample is not None: | |||||
| x_down = self.downsample(x, H, W) | |||||
| Wh, Ww = (H + 1) // 2, (W + 1) // 2 | |||||
| return x, H, W, x_down, Wh, Ww | |||||
| else: | |||||
| return x, H, W, x, H, W | |||||
| class NewCRF(nn.Module): | |||||
| def __init__(self, | |||||
| input_dim=96, | |||||
| embed_dim=96, | |||||
| v_dim=64, | |||||
| window_size=7, | |||||
| num_heads=4, | |||||
| depth=2, | |||||
| patch_size=4, | |||||
| in_chans=3, | |||||
| norm_layer=nn.LayerNorm, | |||||
| patch_norm=True): | |||||
| super().__init__() | |||||
| self.embed_dim = embed_dim | |||||
| self.patch_norm = patch_norm | |||||
| if input_dim != embed_dim: | |||||
| self.proj_x = nn.Conv2d(input_dim, embed_dim, 3, padding=1) | |||||
| else: | |||||
| self.proj_x = None | |||||
| if v_dim != embed_dim: | |||||
| self.proj_v = nn.Conv2d(v_dim, embed_dim, 3, padding=1) | |||||
| elif embed_dim % v_dim == 0: | |||||
| self.proj_v = None | |||||
| v_dim = embed_dim | |||||
| assert v_dim == embed_dim | |||||
| self.crf_layer = BasicCRFLayer( | |||||
| dim=embed_dim, | |||||
| depth=depth, | |||||
| num_heads=num_heads, | |||||
| v_dim=v_dim, | |||||
| window_size=window_size, | |||||
| mlp_ratio=4., | |||||
| qkv_bias=True, | |||||
| qk_scale=None, | |||||
| drop=0., | |||||
| attn_drop=0., | |||||
| drop_path=0., | |||||
| norm_layer=norm_layer, | |||||
| downsample=None, | |||||
| use_checkpoint=False) | |||||
| layer = norm_layer(embed_dim) | |||||
| layer_name = 'norm_crf' | |||||
| self.add_module(layer_name, layer) | |||||
| def forward(self, x, v): | |||||
| if self.proj_x is not None: | |||||
| x = self.proj_x(x) | |||||
| if self.proj_v is not None: | |||||
| v = self.proj_v(v) | |||||
| Wh, Ww = x.size(2), x.size(3) | |||||
| x = x.flatten(2).transpose(1, 2) | |||||
| v = v.transpose(1, 2).transpose(2, 3) | |||||
| x_out, H, W, x, Wh, Ww = self.crf_layer(x, v, Wh, Ww) | |||||
| norm_layer = getattr(self, 'norm_crf') | |||||
| x_out = norm_layer(x_out) | |||||
| out = x_out.view(-1, H, W, self.embed_dim).permute(0, 3, 1, | |||||
| 2).contiguous() | |||||
| return out | |||||
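A small, hedged sketch of the `NewCRF` module in isolation: `x` plays the role of a backbone feature map and `v` the value features from the coarser decoding level; both are projected to `embed_dim` when their channel counts differ, as in the four stages of `NewCRFDepth`.

```python
# Hedged sketch: one NewCRF stage on dummy feature maps.
import torch

crf = NewCRF(input_dim=512, embed_dim=256, v_dim=64, window_size=7, num_heads=8)
x = torch.randn(2, 512, 24, 24)   # backbone feature map, (B, input_dim, H, W)
v = torch.randn(2, 64, 24, 24)    # value features, (B, v_dim, H, W)
out = crf(x, v)                   # -> (2, 256, 24, 24), i.e. (B, embed_dim, H, W)
print(out.shape)
```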
| @@ -0,0 +1,272 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import os.path as osp | |||||
| import pkgutil | |||||
| import warnings | |||||
| from collections import OrderedDict | |||||
| from importlib import import_module | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torchvision | |||||
| from torch import distributed as dist | |||||
| from torch.nn import functional as F | |||||
| from torch.nn.parallel import DataParallel, DistributedDataParallel | |||||
| from torch.utils import model_zoo | |||||
| TORCH_VERSION = torch.__version__ | |||||
| def resize(input, | |||||
| size=None, | |||||
| scale_factor=None, | |||||
| mode='nearest', | |||||
| align_corners=None, | |||||
| warning=True): | |||||
| if warning: | |||||
| if size is not None and align_corners: | |||||
| input_h, input_w = tuple(int(x) for x in input.shape[2:]) | |||||
| output_h, output_w = tuple(int(x) for x in size) | |||||
| if output_h > input_h or output_w > input_w: | |||||
| if ((output_h > 1 and output_w > 1 and input_h > 1 | |||||
| and input_w > 1) and (output_h - 1) % (input_h - 1) | |||||
| and (output_w - 1) % (input_w - 1)): | |||||
| warnings.warn( | |||||
| f'When align_corners={align_corners}, ' | |||||
| 'the output would be more aligned if ' | |||||
| f'input size {(input_h, input_w)} is `x+1` and ' | |||||
| f'out size {(output_h, output_w)} is `nx+1`') | |||||
| if isinstance(size, torch.Size): | |||||
| size = tuple(int(x) for x in size) | |||||
| return F.interpolate(input, size, scale_factor, mode, align_corners) | |||||
| def normal_init(module, mean=0, std=1, bias=0): | |||||
| if hasattr(module, 'weight') and module.weight is not None: | |||||
| nn.init.normal_(module.weight, mean, std) | |||||
| if hasattr(module, 'bias') and module.bias is not None: | |||||
| nn.init.constant_(module.bias, bias) | |||||
| def is_module_wrapper(module): | |||||
| module_wrappers = (DataParallel, DistributedDataParallel) | |||||
| return isinstance(module, module_wrappers) | |||||
| def get_dist_info(): | |||||
| if TORCH_VERSION < '1.0': | |||||
| initialized = dist._initialized | |||||
| else: | |||||
| if dist.is_available(): | |||||
| initialized = dist.is_initialized() | |||||
| else: | |||||
| initialized = False | |||||
| if initialized: | |||||
| rank = dist.get_rank() | |||||
| world_size = dist.get_world_size() | |||||
| else: | |||||
| rank = 0 | |||||
| world_size = 1 | |||||
| return rank, world_size | |||||
| def load_state_dict(module, state_dict, strict=False, logger=None): | |||||
| """Load state_dict to a module. | |||||
| This method is modified from :meth:`torch.nn.Module.load_state_dict`. | |||||
| Default value for ``strict`` is set to ``False`` and the message for | |||||
| param mismatch will be shown even if strict is False. | |||||
| Args: | |||||
| module (Module): Module that receives the state_dict. | |||||
| state_dict (OrderedDict): Weights. | |||||
| strict (bool): whether to strictly enforce that the keys | |||||
| in :attr:`state_dict` match the keys returned by this module's | |||||
| :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. | |||||
| logger (:obj:`logging.Logger`, optional): Logger to log the error | |||||
| message. If not specified, print function will be used. | |||||
| """ | |||||
| unexpected_keys = [] | |||||
| all_missing_keys = [] | |||||
| err_msg = [] | |||||
| metadata = getattr(state_dict, '_metadata', None) | |||||
| state_dict = state_dict.copy() | |||||
| if metadata is not None: | |||||
| state_dict._metadata = metadata | |||||
| # use _load_from_state_dict to enable checkpoint version control | |||||
| def load(module, prefix=''): | |||||
| # recursively check parallel module in case that the model has a | |||||
| # complicated structure, e.g., nn.Module(nn.Module(DDP)) | |||||
| if is_module_wrapper(module): | |||||
| module = module.module | |||||
| local_metadata = {} if metadata is None else metadata.get( | |||||
| prefix[:-1], {}) | |||||
| module._load_from_state_dict(state_dict, prefix, local_metadata, True, | |||||
| all_missing_keys, unexpected_keys, | |||||
| err_msg) | |||||
| for name, child in module._modules.items(): | |||||
| if child is not None: | |||||
| load(child, prefix + name + '.') | |||||
| load(module) | |||||
| load = None # break load->load reference cycle | |||||
| # ignore "num_batches_tracked" of BN layers | |||||
| missing_keys = [ | |||||
| key for key in all_missing_keys if 'num_batches_tracked' not in key | |||||
| ] | |||||
| if unexpected_keys: | |||||
| err_msg.append('unexpected key in source ' | |||||
| f'state_dict: {", ".join(unexpected_keys)}\n') | |||||
| if missing_keys: | |||||
| err_msg.append( | |||||
| f'missing keys in source state_dict: {", ".join(missing_keys)}\n') | |||||
| rank, _ = get_dist_info() | |||||
| if len(err_msg) > 0 and rank == 0: | |||||
| err_msg.insert( | |||||
| 0, 'The model and loaded state dict do not match exactly\n') | |||||
| err_msg = '\n'.join(err_msg) | |||||
| if strict: | |||||
| raise RuntimeError(err_msg) | |||||
| elif logger is not None: | |||||
| logger.warning(err_msg) | |||||
| else: | |||||
| print(err_msg) | |||||
| def load_url_dist(url, model_dir=None): | |||||
| """In distributed setting, this function only download checkpoint at local | |||||
| rank 0.""" | |||||
| rank, world_size = get_dist_info() | |||||
| rank = int(os.environ.get('LOCAL_RANK', rank)) | |||||
| if rank == 0: | |||||
| checkpoint = model_zoo.load_url(url, model_dir=model_dir) | |||||
| if world_size > 1: | |||||
| torch.distributed.barrier() | |||||
| if rank > 0: | |||||
| checkpoint = model_zoo.load_url(url, model_dir=model_dir) | |||||
| return checkpoint | |||||
| def get_torchvision_models(): | |||||
| model_urls = dict() | |||||
| for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): | |||||
| if ispkg: | |||||
| continue | |||||
| _zoo = import_module(f'torchvision.models.{name}') | |||||
| if hasattr(_zoo, 'model_urls'): | |||||
| _urls = getattr(_zoo, 'model_urls') | |||||
| model_urls.update(_urls) | |||||
| return model_urls | |||||
| def _load_checkpoint(filename, map_location=None): | |||||
| """Load checkpoint from somewhere (modelzoo, file, url). | |||||
| Args: | |||||
| filename (str): Accept local filepath, URL, ``torchvision://xxx``, | |||||
| ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for | |||||
| details. | |||||
| map_location (str | None): Same as :func:`torch.load`. Default: None. | |||||
| Returns: | |||||
| dict | OrderedDict: The loaded checkpoint. It can be either an | |||||
| OrderedDict storing model weights or a dict containing other | |||||
| information, which depends on the checkpoint. | |||||
| """ | |||||
| if filename.startswith('modelzoo://'): | |||||
| warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' | |||||
| 'use "torchvision://" instead') | |||||
| model_urls = get_torchvision_models() | |||||
| model_name = filename[11:] | |||||
| checkpoint = load_url_dist(model_urls[model_name]) | |||||
| else: | |||||
| if not osp.isfile(filename): | |||||
| raise IOError(f'{filename} is not a checkpoint file') | |||||
| checkpoint = torch.load(filename, map_location=map_location) | |||||
| return checkpoint | |||||
| def load_checkpoint(model, | |||||
| filename, | |||||
| map_location='cpu', | |||||
| strict=False, | |||||
| logger=None): | |||||
| """Load checkpoint from a file or URI. | |||||
| Args: | |||||
| model (Module): Module to load checkpoint. | |||||
| filename (str): Accept local filepath, URL, ``torchvision://xxx``, | |||||
| ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for | |||||
| details. | |||||
| map_location (str): Same as :func:`torch.load`. | |||||
| strict (bool): Whether to allow different params for the model and | |||||
| checkpoint. | |||||
| logger (:mod:`logging.Logger` or None): The logger for error message. | |||||
| Returns: | |||||
| dict or OrderedDict: The loaded checkpoint. | |||||
| """ | |||||
| checkpoint = _load_checkpoint(filename, map_location) | |||||
| # OrderedDict is a subclass of dict | |||||
| if not isinstance(checkpoint, dict): | |||||
| raise RuntimeError( | |||||
| f'No state_dict found in checkpoint file {filename}') | |||||
| # get state_dict from checkpoint | |||||
| if 'state_dict' in checkpoint: | |||||
| state_dict = checkpoint['state_dict'] | |||||
| elif 'model' in checkpoint: | |||||
| state_dict = checkpoint['model'] | |||||
| else: | |||||
| state_dict = checkpoint | |||||
| # strip prefix of state_dict | |||||
| if list(state_dict.keys())[0].startswith('module.'): | |||||
| state_dict = {k[7:]: v for k, v in state_dict.items()} | |||||
| # for MoBY, load model of online branch | |||||
| if sorted(list(state_dict.keys()))[0].startswith('encoder'): | |||||
| state_dict = { | |||||
| k.replace('encoder.', ''): v | |||||
| for k, v in state_dict.items() if k.startswith('encoder.') | |||||
| } | |||||
| # reshape absolute position embedding | |||||
| if state_dict.get('absolute_pos_embed') is not None: | |||||
| absolute_pos_embed = state_dict['absolute_pos_embed'] | |||||
| N1, L, C1 = absolute_pos_embed.size() | |||||
| N2, C2, H, W = model.absolute_pos_embed.size() | |||||
| if N1 != N2 or C1 != C2 or L != H * W: | |||||
| logger.warning('Error in loading absolute_pos_embed, pass') | |||||
| else: | |||||
| state_dict['absolute_pos_embed'] = absolute_pos_embed.view( | |||||
| N2, H, W, C2).permute(0, 3, 1, 2) | |||||
| # interpolate position bias table if needed | |||||
| relative_position_bias_table_keys = [ | |||||
| k for k in state_dict.keys() if 'relative_position_bias_table' in k | |||||
| ] | |||||
| for table_key in relative_position_bias_table_keys: | |||||
| table_pretrained = state_dict[table_key] | |||||
| table_current = model.state_dict()[table_key] | |||||
| L1, nH1 = table_pretrained.size() | |||||
| L2, nH2 = table_current.size() | |||||
| if nH1 != nH2: | |||||
| logger.warning(f'Error in loading {table_key}, pass') | |||||
| else: | |||||
| if L1 != L2: | |||||
| S1 = int(L1**0.5) | |||||
| S2 = int(L2**0.5) | |||||
| table_pretrained_resized = F.interpolate( | |||||
| table_pretrained.permute(1, 0).view(1, nH1, S1, S1), | |||||
| size=(S2, S2), | |||||
| mode='bicubic') | |||||
| state_dict[table_key] = table_pretrained_resized.view( | |||||
| nH2, L2).permute(1, 0) | |||||
| # load state_dict | |||||
| load_state_dict(model, state_dict, strict, logger) | |||||
| return checkpoint | |||||
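A hedged example of the loader above: non-strict loading of Swin-style weights, so the `module.`/`encoder.` prefix stripping and relative-position-bias interpolation apply and mismatched keys are logged rather than fatal. The checkpoint path is a placeholder, and `SwinTransformer` refers to the backbone defined in the accompanying swin_transformer.py.

```python
# Hedged sketch: load pretrained Swin weights into the backbone with strict=False.
from modelscope.utils.logger import get_logger

backbone = SwinTransformer(
    embed_dim=192,
    depths=[2, 2, 18, 2],
    num_heads=[6, 12, 24, 48],
    window_size=7)
ckpt = load_checkpoint(
    backbone,
    '/path/to/swin_large_patch4_window7_224_22k.pth',   # placeholder checkpoint path
    map_location='cpu',
    strict=False,
    logger=get_logger())
```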
| @@ -0,0 +1,706 @@ | |||||
| # The implementation is adopted from Swin Transformer | |||||
| # made publicly available under the MIT License at https://github.com/microsoft/Swin-Transformer | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| import torch.utils.checkpoint as checkpoint | |||||
| from timm.models.layers import DropPath, to_2tuple, trunc_normal_ | |||||
| from .newcrf_utils import load_checkpoint | |||||
| class Mlp(nn.Module): | |||||
| """ Multilayer perceptron.""" | |||||
| def __init__(self, | |||||
| in_features, | |||||
| hidden_features=None, | |||||
| out_features=None, | |||||
| act_layer=nn.GELU, | |||||
| drop=0.): | |||||
| super().__init__() | |||||
| out_features = out_features or in_features | |||||
| hidden_features = hidden_features or in_features | |||||
| self.fc1 = nn.Linear(in_features, hidden_features) | |||||
| self.act = act_layer() | |||||
| self.fc2 = nn.Linear(hidden_features, out_features) | |||||
| self.drop = nn.Dropout(drop) | |||||
| def forward(self, x): | |||||
| x = self.fc1(x) | |||||
| x = self.act(x) | |||||
| x = self.drop(x) | |||||
| x = self.fc2(x) | |||||
| x = self.drop(x) | |||||
| return x | |||||
| def window_partition(x, window_size): | |||||
| """ | |||||
| Args: | |||||
| x: (B, H, W, C) | |||||
| window_size (int): window size | |||||
| Returns: | |||||
| windows: (num_windows*B, window_size, window_size, C) | |||||
| """ | |||||
| B, H, W, C = x.shape | |||||
| x = x.view(B, H // window_size, window_size, W // window_size, window_size, | |||||
| C) | |||||
| windows = x.permute(0, 1, 3, 2, 4, | |||||
| 5).contiguous().view(-1, window_size, window_size, C) | |||||
| return windows | |||||
| def window_reverse(windows, window_size, H, W): | |||||
| """ | |||||
| Args: | |||||
| windows: (num_windows*B, window_size, window_size, C) | |||||
| window_size (int): Window size | |||||
| H (int): Height of image | |||||
| W (int): Width of image | |||||
| Returns: | |||||
| x: (B, H, W, C) | |||||
| """ | |||||
| B = int(windows.shape[0] / (H * W / window_size / window_size)) | |||||
| x = windows.view(B, H // window_size, W // window_size, window_size, | |||||
| window_size, -1) | |||||
| x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) | |||||
| return x | |||||
| class WindowAttention(nn.Module): | |||||
| """ Window based multi-head self attention (W-MSA) module with relative position bias. | |||||
| It supports both shifted and non-shifted windows. | |||||
| Args: | |||||
| dim (int): Number of input channels. | |||||
| window_size (tuple[int]): The height and width of the window. | |||||
| num_heads (int): Number of attention heads. | |||||
| qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||||
| qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set | |||||
| attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 | |||||
| proj_drop (float, optional): Dropout ratio of output. Default: 0.0 | |||||
| """ | |||||
| def __init__(self, | |||||
| dim, | |||||
| window_size, | |||||
| num_heads, | |||||
| qkv_bias=True, | |||||
| qk_scale=None, | |||||
| attn_drop=0., | |||||
| proj_drop=0.): | |||||
| super().__init__() | |||||
| self.dim = dim | |||||
| self.window_size = window_size # Wh, Ww | |||||
| self.num_heads = num_heads | |||||
| head_dim = dim // num_heads | |||||
| self.scale = qk_scale or head_dim**-0.5 | |||||
| # define a parameter table of relative position bias | |||||
| self.relative_position_bias_table = nn.Parameter( | |||||
| torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), | |||||
| num_heads)) # 2*Wh-1 * 2*Ww-1, nH | |||||
| # get pair-wise relative position index for each token inside the window | |||||
| coords_h = torch.arange(self.window_size[0]) | |||||
| coords_w = torch.arange(self.window_size[1]) | |||||
| coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww | |||||
| coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww | |||||
| relative_coords = coords_flatten[:, :, | |||||
| None] - coords_flatten[:, | |||||
| None, :] # 2, Wh*Ww, Wh*Ww | |||||
| relative_coords = relative_coords.permute( | |||||
| 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 | |||||
| relative_coords[:, :, | |||||
| 0] += self.window_size[0] - 1 # shift to start from 0 | |||||
| relative_coords[:, :, 1] += self.window_size[1] - 1 | |||||
| relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 | |||||
| relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww | |||||
| self.register_buffer('relative_position_index', | |||||
| relative_position_index) | |||||
| self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | |||||
| self.attn_drop = nn.Dropout(attn_drop) | |||||
| self.proj = nn.Linear(dim, dim) | |||||
| self.proj_drop = nn.Dropout(proj_drop) | |||||
| trunc_normal_(self.relative_position_bias_table, std=.02) | |||||
| self.softmax = nn.Softmax(dim=-1) | |||||
| def forward(self, x, mask=None): | |||||
| """ Forward function. | |||||
| Args: | |||||
| x: input features with shape of (num_windows*B, N, C) | |||||
| mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None | |||||
| """ | |||||
| B_, N, C = x.shape | |||||
| qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, | |||||
| C // self.num_heads).permute(2, 0, 3, 1, 4) | |||||
| q, k, v = qkv[0], qkv[1], qkv[ | |||||
| 2] # make torchscript happy (cannot use tensor as tuple) | |||||
| q = q * self.scale | |||||
| attn = (q @ k.transpose(-2, -1)) | |||||
| relative_position_bias = self.relative_position_bias_table[ | |||||
| self.relative_position_index.view(-1)].view( | |||||
| self.window_size[0] * self.window_size[1], | |||||
| self.window_size[0] * self.window_size[1], | |||||
| -1) # Wh*Ww,Wh*Ww,nH | |||||
| relative_position_bias = relative_position_bias.permute( | |||||
| 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww | |||||
| attn = attn + relative_position_bias.unsqueeze(0) | |||||
| if mask is not None: | |||||
| nW = mask.shape[0] | |||||
| attn = attn.view(B_ // nW, nW, self.num_heads, N, | |||||
| N) + mask.unsqueeze(1).unsqueeze(0) | |||||
| attn = attn.view(-1, self.num_heads, N, N) | |||||
| attn = self.softmax(attn) | |||||
| else: | |||||
| attn = self.softmax(attn) | |||||
| attn = self.attn_drop(attn) | |||||
| x = (attn @ v).transpose(1, 2).reshape(B_, N, C) | |||||
| x = self.proj(x) | |||||
| x = self.proj_drop(x) | |||||
| return x | |||||
| class SwinTransformerBlock(nn.Module): | |||||
| """ Swin Transformer Block. | |||||
| Args: | |||||
| dim (int): Number of input channels. | |||||
| num_heads (int): Number of attention heads. | |||||
| window_size (int): Window size. | |||||
| shift_size (int): Shift size for SW-MSA. | |||||
| mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. | |||||
| qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||||
| qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. | |||||
| drop (float, optional): Dropout rate. Default: 0.0 | |||||
| attn_drop (float, optional): Attention dropout rate. Default: 0.0 | |||||
| drop_path (float, optional): Stochastic depth rate. Default: 0.0 | |||||
| act_layer (nn.Module, optional): Activation layer. Default: nn.GELU | |||||
| norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||||
| """ | |||||
| def __init__(self, | |||||
| dim, | |||||
| num_heads, | |||||
| window_size=7, | |||||
| shift_size=0, | |||||
| mlp_ratio=4., | |||||
| qkv_bias=True, | |||||
| qk_scale=None, | |||||
| drop=0., | |||||
| attn_drop=0., | |||||
| drop_path=0., | |||||
| act_layer=nn.GELU, | |||||
| norm_layer=nn.LayerNorm): | |||||
| super().__init__() | |||||
| self.dim = dim | |||||
| self.num_heads = num_heads | |||||
| self.window_size = window_size | |||||
| self.shift_size = shift_size | |||||
| self.mlp_ratio = mlp_ratio | |||||
| assert 0 <= self.shift_size < self.window_size, 'shift_size must be in [0, window_size)' | |||||
| self.norm1 = norm_layer(dim) | |||||
| self.attn = WindowAttention( | |||||
| dim, | |||||
| window_size=to_2tuple(self.window_size), | |||||
| num_heads=num_heads, | |||||
| qkv_bias=qkv_bias, | |||||
| qk_scale=qk_scale, | |||||
| attn_drop=attn_drop, | |||||
| proj_drop=drop) | |||||
| self.drop_path = DropPath( | |||||
| drop_path) if drop_path > 0. else nn.Identity() | |||||
| self.norm2 = norm_layer(dim) | |||||
| mlp_hidden_dim = int(dim * mlp_ratio) | |||||
| self.mlp = Mlp( | |||||
| in_features=dim, | |||||
| hidden_features=mlp_hidden_dim, | |||||
| act_layer=act_layer, | |||||
| drop=drop) | |||||
| self.H = None | |||||
| self.W = None | |||||
| def forward(self, x, mask_matrix): | |||||
| """ Forward function. | |||||
| Args: | |||||
| x: Input feature, tensor size (B, H*W, C). | |||||
| H, W: Spatial resolution of the input feature. | |||||
| mask_matrix: Attention mask for cyclic shift. | |||||
| """ | |||||
| B, L, C = x.shape | |||||
| H, W = self.H, self.W | |||||
| assert L == H * W, 'input feature has wrong size' | |||||
| shortcut = x | |||||
| x = self.norm1(x) | |||||
| x = x.view(B, H, W, C) | |||||
| # pad feature maps to multiples of window size | |||||
| pad_l = pad_t = 0 | |||||
| pad_r = (self.window_size - W % self.window_size) % self.window_size | |||||
| pad_b = (self.window_size - H % self.window_size) % self.window_size | |||||
| x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) | |||||
| _, Hp, Wp, _ = x.shape | |||||
| # cyclic shift | |||||
| if self.shift_size > 0: | |||||
| shifted_x = torch.roll( | |||||
| x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) | |||||
| attn_mask = mask_matrix | |||||
| else: | |||||
| shifted_x = x | |||||
| attn_mask = None | |||||
| # partition windows | |||||
| x_windows = window_partition( | |||||
| shifted_x, self.window_size) # nW*B, window_size, window_size, C | |||||
| x_windows = x_windows.view(-1, self.window_size * self.window_size, | |||||
| C) # nW*B, window_size*window_size, C | |||||
| # W-MSA/SW-MSA | |||||
| attn_windows = self.attn( | |||||
| x_windows, mask=attn_mask) # nW*B, window_size*window_size, C | |||||
| # merge windows | |||||
| attn_windows = attn_windows.view(-1, self.window_size, | |||||
| self.window_size, C) | |||||
| shifted_x = window_reverse(attn_windows, self.window_size, Hp, | |||||
| Wp) # B H' W' C | |||||
| # reverse cyclic shift | |||||
| if self.shift_size > 0: | |||||
| x = torch.roll( | |||||
| shifted_x, | |||||
| shifts=(self.shift_size, self.shift_size), | |||||
| dims=(1, 2)) | |||||
| else: | |||||
| x = shifted_x | |||||
| if pad_r > 0 or pad_b > 0: | |||||
| x = x[:, :H, :W, :].contiguous() | |||||
| x = x.view(B, H * W, C) | |||||
| # FFN | |||||
| x = shortcut + self.drop_path(x) | |||||
| x = x + self.drop_path(self.mlp(self.norm2(x))) | |||||
| return x | |||||
| class PatchMerging(nn.Module): | |||||
| """ Patch Merging Layer | |||||
| Args: | |||||
| dim (int): Number of input channels. | |||||
| norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||||
| """ | |||||
| def __init__(self, dim, norm_layer=nn.LayerNorm): | |||||
| super().__init__() | |||||
| self.dim = dim | |||||
| self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) | |||||
| self.norm = norm_layer(4 * dim) | |||||
| def forward(self, x, H, W): | |||||
| """ Forward function. | |||||
| Args: | |||||
| x: Input feature, tensor size (B, H*W, C). | |||||
| H, W: Spatial resolution of the input feature. | |||||
| """ | |||||
| B, L, C = x.shape | |||||
| assert L == H * W, 'input feature has wrong size' | |||||
| x = x.view(B, H, W, C) | |||||
| # padding | |||||
| pad_input = (H % 2 == 1) or (W % 2 == 1) | |||||
| if pad_input: | |||||
| x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) | |||||
| x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C | |||||
| x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C | |||||
| x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C | |||||
| x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C | |||||
| x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C | |||||
| x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C | |||||
| x = self.norm(x) | |||||
| x = self.reduction(x) | |||||
| return x | |||||
| class BasicLayer(nn.Module): | |||||
| """ A basic Swin Transformer layer for one stage. | |||||
| Args: | |||||
| dim (int): Number of feature channels | |||||
| depth (int): Depths of this stage. | |||||
| num_heads (int): Number of attention head. | |||||
| window_size (int): Local window size. Default: 7. | |||||
| mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. | |||||
| qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||||
| qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. | |||||
| drop (float, optional): Dropout rate. Default: 0.0 | |||||
| attn_drop (float, optional): Attention dropout rate. Default: 0.0 | |||||
| drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 | |||||
| norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||||
| downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None | |||||
| use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. | |||||
| """ | |||||
| def __init__(self, | |||||
| dim, | |||||
| depth, | |||||
| num_heads, | |||||
| window_size=7, | |||||
| mlp_ratio=4., | |||||
| qkv_bias=True, | |||||
| qk_scale=None, | |||||
| drop=0., | |||||
| attn_drop=0., | |||||
| drop_path=0., | |||||
| norm_layer=nn.LayerNorm, | |||||
| downsample=None, | |||||
| use_checkpoint=False): | |||||
| super().__init__() | |||||
| self.window_size = window_size | |||||
| self.shift_size = window_size // 2 | |||||
| self.depth = depth | |||||
| self.use_checkpoint = use_checkpoint | |||||
| # build blocks | |||||
| self.blocks = nn.ModuleList([ | |||||
| SwinTransformerBlock( | |||||
| dim=dim, | |||||
| num_heads=num_heads, | |||||
| window_size=window_size, | |||||
| shift_size=0 if (i % 2 == 0) else window_size // 2, | |||||
| mlp_ratio=mlp_ratio, | |||||
| qkv_bias=qkv_bias, | |||||
| qk_scale=qk_scale, | |||||
| drop=drop, | |||||
| attn_drop=attn_drop, | |||||
| drop_path=drop_path[i] | |||||
| if isinstance(drop_path, list) else drop_path, | |||||
| norm_layer=norm_layer) for i in range(depth) | |||||
| ]) | |||||
| # patch merging layer | |||||
| if downsample is not None: | |||||
| self.downsample = downsample(dim=dim, norm_layer=norm_layer) | |||||
| else: | |||||
| self.downsample = None | |||||
| def forward(self, x, H, W): | |||||
| """ Forward function. | |||||
| Args: | |||||
| x: Input feature, tensor size (B, H*W, C). | |||||
| H, W: Spatial resolution of the input feature. | |||||
| """ | |||||
| # calculate attention mask for SW-MSA | |||||
| Hp = int(np.ceil(H / self.window_size)) * self.window_size | |||||
| Wp = int(np.ceil(W / self.window_size)) * self.window_size | |||||
| img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 | |||||
| h_slices = (slice(0, -self.window_size), | |||||
| slice(-self.window_size, | |||||
| -self.shift_size), slice(-self.shift_size, None)) | |||||
| w_slices = (slice(0, -self.window_size), | |||||
| slice(-self.window_size, | |||||
| -self.shift_size), slice(-self.shift_size, None)) | |||||
| cnt = 0 | |||||
| for h in h_slices: | |||||
| for w in w_slices: | |||||
| img_mask[:, h, w, :] = cnt | |||||
| cnt += 1 | |||||
| mask_windows = window_partition( | |||||
| img_mask, self.window_size) # nW, window_size, window_size, 1 | |||||
| mask_windows = mask_windows.view(-1, | |||||
| self.window_size * self.window_size) | |||||
| attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) | |||||
| attn_mask = attn_mask.masked_fill(attn_mask != 0, | |||||
| float(-100.0)).masked_fill( | |||||
| attn_mask == 0, float(0.0)) | |||||
| for blk in self.blocks: | |||||
| blk.H, blk.W = H, W | |||||
| if self.use_checkpoint: | |||||
| x = checkpoint.checkpoint(blk, x, attn_mask) | |||||
| else: | |||||
| x = blk(x, attn_mask) | |||||
| if self.downsample is not None: | |||||
| x_down = self.downsample(x, H, W) | |||||
| Wh, Ww = (H + 1) // 2, (W + 1) // 2 | |||||
| return x, H, W, x_down, Wh, Ww | |||||
| else: | |||||
| return x, H, W, x, H, W | |||||
| class PatchEmbed(nn.Module): | |||||
| """ Image to Patch Embedding | |||||
| Args: | |||||
| patch_size (int): Patch token size. Default: 4. | |||||
| in_chans (int): Number of input image channels. Default: 3. | |||||
| embed_dim (int): Number of linear projection output channels. Default: 96. | |||||
| norm_layer (nn.Module, optional): Normalization layer. Default: None | |||||
| """ | |||||
| def __init__(self, | |||||
| patch_size=4, | |||||
| in_chans=3, | |||||
| embed_dim=96, | |||||
| norm_layer=None): | |||||
| super().__init__() | |||||
| patch_size = to_2tuple(patch_size) | |||||
| self.patch_size = patch_size | |||||
| self.in_chans = in_chans | |||||
| self.embed_dim = embed_dim | |||||
| self.proj = nn.Conv2d( | |||||
| in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) | |||||
| if norm_layer is not None: | |||||
| self.norm = norm_layer(embed_dim) | |||||
| else: | |||||
| self.norm = None | |||||
| def forward(self, x): | |||||
| """Forward function.""" | |||||
| # padding | |||||
| _, _, H, W = x.size() | |||||
| if W % self.patch_size[1] != 0: | |||||
| x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) | |||||
| if H % self.patch_size[0] != 0: | |||||
| x = F.pad(x, | |||||
| (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) | |||||
| x = self.proj(x) # B C Wh Ww | |||||
| if self.norm is not None: | |||||
| Wh, Ww = x.size(2), x.size(3) | |||||
| x = x.flatten(2).transpose(1, 2) | |||||
| x = self.norm(x) | |||||
| x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) | |||||
| return x | |||||
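| # --- Illustrative sketch, not part of the original source: inputs whose height | |||||
| # or width is not a multiple of patch_size are padded up before projection. | |||||
| # Guarded so it only runs when this file is executed directly. | |||||
| if __name__ == '__main__': | |||||
| _pe = PatchEmbed(patch_size=4, in_chans=3, embed_dim=96) | |||||
| _out = _pe(torch.randn(1, 3, 225, 318)) # neither 225 nor 318 is divisible by 4 | |||||
| print(tuple(_out.shape)) # (1, 96, 57, 80): ceil(225 / 4) x ceil(318 / 4) | |||||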
| class SwinTransformer(nn.Module): | |||||
| """ Swin Transformer backbone. | |||||
| A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - | |||||
| https://arxiv.org/pdf/2103.14030 | |||||
| Args: | |||||
| pretrain_img_size (int): Input image size for training the pretrained model, | |||||
| used in absolute position embedding. Default 224. | |||||
| patch_size (int | tuple(int)): Patch size. Default: 4. | |||||
| in_chans (int): Number of input image channels. Default: 3. | |||||
| embed_dim (int): Number of linear projection output channels. Default: 96. | |||||
| depths (tuple[int]): Depths of each Swin Transformer stage. | |||||
| num_heads (tuple[int]): Number of attention head of each stage. | |||||
| window_size (int): Window size. Default: 7. | |||||
| mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. | |||||
| qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True | |||||
| qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. | |||||
| drop_rate (float): Dropout rate. Default: 0. | |||||
| attn_drop_rate (float): Attention dropout rate. Default: 0. | |||||
| drop_path_rate (float): Stochastic depth rate. Default: 0.2. | |||||
| norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. | |||||
| ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. | |||||
| patch_norm (bool): If True, add normalization after patch embedding. Default: True. | |||||
| out_indices (Sequence[int]): Output from which stages. | |||||
| frozen_stages (int): Stages to be frozen (stop grad and set eval mode). | |||||
| -1 means not freezing any parameters. | |||||
| use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. | |||||
| """ | |||||
| def __init__(self, | |||||
| pretrain_img_size=224, | |||||
| patch_size=4, | |||||
| in_chans=3, | |||||
| embed_dim=96, | |||||
| depths=[2, 2, 6, 2], | |||||
| num_heads=[3, 6, 12, 24], | |||||
| window_size=7, | |||||
| mlp_ratio=4., | |||||
| qkv_bias=True, | |||||
| qk_scale=None, | |||||
| drop_rate=0., | |||||
| attn_drop_rate=0., | |||||
| drop_path_rate=0.2, | |||||
| norm_layer=nn.LayerNorm, | |||||
| ape=False, | |||||
| patch_norm=True, | |||||
| out_indices=(0, 1, 2, 3), | |||||
| frozen_stages=-1, | |||||
| use_checkpoint=False): | |||||
| super().__init__() | |||||
| self.pretrain_img_size = pretrain_img_size | |||||
| self.num_layers = len(depths) | |||||
| self.embed_dim = embed_dim | |||||
| self.ape = ape | |||||
| self.patch_norm = patch_norm | |||||
| self.out_indices = out_indices | |||||
| self.frozen_stages = frozen_stages | |||||
| # split image into non-overlapping patches | |||||
| self.patch_embed = PatchEmbed( | |||||
| patch_size=patch_size, | |||||
| in_chans=in_chans, | |||||
| embed_dim=embed_dim, | |||||
| norm_layer=norm_layer if self.patch_norm else None) | |||||
| # absolute position embedding | |||||
| if self.ape: | |||||
| pretrain_img_size = to_2tuple(pretrain_img_size) | |||||
| patch_size = to_2tuple(patch_size) | |||||
| patches_resolution = [ | |||||
| pretrain_img_size[0] // patch_size[0], | |||||
| pretrain_img_size[1] // patch_size[1] | |||||
| ] | |||||
| self.absolute_pos_embed = nn.Parameter( | |||||
| torch.zeros(1, embed_dim, patches_resolution[0], | |||||
| patches_resolution[1])) | |||||
| trunc_normal_(self.absolute_pos_embed, std=.02) | |||||
| self.pos_drop = nn.Dropout(p=drop_rate) | |||||
| # stochastic depth | |||||
| dpr = [ | |||||
| x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) | |||||
| ] # stochastic depth decay rule | |||||
| # build layers | |||||
| self.layers = nn.ModuleList() | |||||
| for i_layer in range(self.num_layers): | |||||
| layer = BasicLayer( | |||||
| dim=int(embed_dim * 2**i_layer), | |||||
| depth=depths[i_layer], | |||||
| num_heads=num_heads[i_layer], | |||||
| window_size=window_size, | |||||
| mlp_ratio=mlp_ratio, | |||||
| qkv_bias=qkv_bias, | |||||
| qk_scale=qk_scale, | |||||
| drop=drop_rate, | |||||
| attn_drop=attn_drop_rate, | |||||
| drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], | |||||
| norm_layer=norm_layer, | |||||
| downsample=PatchMerging if | |||||
| (i_layer < self.num_layers - 1) else None, | |||||
| use_checkpoint=use_checkpoint) | |||||
| self.layers.append(layer) | |||||
| num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] | |||||
| self.num_features = num_features | |||||
| # add a norm layer for each output | |||||
| for i_layer in out_indices: | |||||
| layer = norm_layer(num_features[i_layer]) | |||||
| layer_name = f'norm{i_layer}' | |||||
| self.add_module(layer_name, layer) | |||||
| self._freeze_stages() | |||||
| def _freeze_stages(self): | |||||
| if self.frozen_stages >= 0: | |||||
| self.patch_embed.eval() | |||||
| for param in self.patch_embed.parameters(): | |||||
| param.requires_grad = False | |||||
| if self.frozen_stages >= 1 and self.ape: | |||||
| self.absolute_pos_embed.requires_grad = False | |||||
| if self.frozen_stages >= 2: | |||||
| self.pos_drop.eval() | |||||
| for i in range(0, self.frozen_stages - 1): | |||||
| m = self.layers[i] | |||||
| m.eval() | |||||
| for param in m.parameters(): | |||||
| param.requires_grad = False | |||||
| def init_weights(self, pretrained=None): | |||||
| """Initialize the weights in backbone. | |||||
| Args: | |||||
| pretrained (str, optional): Path to pre-trained weights. | |||||
| Defaults to None. | |||||
| """ | |||||
| def _init_weights(m): | |||||
| if isinstance(m, nn.Linear): | |||||
| trunc_normal_(m.weight, std=.02) | |||||
| if isinstance(m, nn.Linear) and m.bias is not None: | |||||
| nn.init.constant_(m.bias, 0) | |||||
| elif isinstance(m, nn.LayerNorm): | |||||
| nn.init.constant_(m.bias, 0) | |||||
| nn.init.constant_(m.weight, 1.0) | |||||
| if isinstance(pretrained, str): | |||||
| self.apply(_init_weights) | |||||
| # logger = get_root_logger() | |||||
| load_checkpoint(self, pretrained, strict=False) | |||||
| elif pretrained is None: | |||||
| self.apply(_init_weights) | |||||
| else: | |||||
| raise TypeError('pretrained must be a str or None') | |||||
| def forward(self, x): | |||||
| """Forward function.""" | |||||
| x = self.patch_embed(x) | |||||
| Wh, Ww = x.size(2), x.size(3) | |||||
| if self.ape: | |||||
| # interpolate the position embedding to the corresponding size | |||||
| absolute_pos_embed = F.interpolate( | |||||
| self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') | |||||
| x = (x + absolute_pos_embed).flatten(2).transpose(1, | |||||
| 2) # B Wh*Ww C | |||||
| else: | |||||
| x = x.flatten(2).transpose(1, 2) | |||||
| x = self.pos_drop(x) | |||||
| outs = [] | |||||
| for i in range(self.num_layers): | |||||
| layer = self.layers[i] | |||||
| x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) | |||||
| if i in self.out_indices: | |||||
| norm_layer = getattr(self, f'norm{i}') | |||||
| x_out = norm_layer(x_out) | |||||
| out = x_out.view(-1, H, W, | |||||
| self.num_features[i]).permute(0, 3, 1, | |||||
| 2).contiguous() | |||||
| outs.append(out) | |||||
| return tuple(outs) | |||||
| def train(self, mode=True): | |||||
| """Convert the model into training mode while keep layers freezed.""" | |||||
| super(SwinTransformer, self).train(mode) | |||||
| self._freeze_stages() | |||||
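| # --- Illustrative sketch, not part of the original source: the backbone returns | |||||
| # one feature map per stage at strides 4/8/16/32. The Swin-T hyper-parameters | |||||
| # below are this class's defaults, not values taken from the depth model config. | |||||
| if __name__ == '__main__': | |||||
| _backbone = SwinTransformer( | |||||
| embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24]) | |||||
| _feats = _backbone(torch.randn(1, 3, 224, 224)) | |||||
| print([tuple(f.shape) for f in _feats]) | |||||
| # [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)] | |||||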
| @@ -0,0 +1,365 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from mmcv.cnn import ConvModule | |||||
| from .newcrf_utils import normal_init, resize | |||||
| class PPM(nn.ModuleList): | |||||
| """Pooling Pyramid Module used in PSPNet. | |||||
| Args: | |||||
| pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid | |||||
| Module. | |||||
| in_channels (int): Input channels. | |||||
| channels (int): Channels after modules, before conv_seg. | |||||
| conv_cfg (dict|None): Config of conv layers. | |||||
| norm_cfg (dict|None): Config of norm layers. | |||||
| act_cfg (dict): Config of activation layers. | |||||
| align_corners (bool): align_corners argument of F.interpolate. | |||||
| """ | |||||
| def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, | |||||
| act_cfg, align_corners): | |||||
| super(PPM, self).__init__() | |||||
| self.pool_scales = pool_scales | |||||
| self.align_corners = align_corners | |||||
| self.in_channels = in_channels | |||||
| self.channels = channels | |||||
| self.conv_cfg = conv_cfg | |||||
| self.norm_cfg = norm_cfg | |||||
| self.act_cfg = act_cfg | |||||
| for pool_scale in pool_scales: | |||||
| # with a 1x1 pooled output, BN fails for batch size 1, so switch to GN for this branch | |||||
| if pool_scale == 1: | |||||
| norm_cfg = dict(type='GN', requires_grad=True, num_groups=256) | |||||
| self.append( | |||||
| nn.Sequential( | |||||
| nn.AdaptiveAvgPool2d(pool_scale), | |||||
| ConvModule( | |||||
| self.in_channels, | |||||
| self.channels, | |||||
| 1, | |||||
| conv_cfg=self.conv_cfg, | |||||
| norm_cfg=norm_cfg, | |||||
| act_cfg=self.act_cfg))) | |||||
| def forward(self, x): | |||||
| """Forward function.""" | |||||
| ppm_outs = [] | |||||
| for ppm in self: | |||||
| ppm_out = ppm(x) | |||||
| upsampled_ppm_out = resize( | |||||
| ppm_out, | |||||
| size=x.size()[2:], | |||||
| mode='bilinear', | |||||
| align_corners=self.align_corners) | |||||
| ppm_outs.append(upsampled_ppm_out) | |||||
| return ppm_outs | |||||
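| # --- Illustrative sketch, not part of the original source: the PPM returns one | |||||
| # context map per pooling scale, all upsampled back to the input resolution. | |||||
| # The channel sizes below are assumptions made only for this example. | |||||
| if __name__ == '__main__': | |||||
| _ppm = PPM( | |||||
| pool_scales=(1, 2, 3, 6), | |||||
| in_channels=512, | |||||
| channels=256, # must be divisible by the 256 GN groups used for pool_scale 1 | |||||
| conv_cfg=None, | |||||
| norm_cfg=dict(type='BN'), | |||||
| act_cfg=dict(type='ReLU'), | |||||
| align_corners=False) | |||||
| _outs = _ppm(torch.randn(2, 512, 32, 32)) | |||||
| print([tuple(o.shape) for o in _outs]) # 4 x (2, 256, 32, 32) | |||||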
| class BaseDecodeHead(nn.Module): | |||||
| """Base class for BaseDecodeHead. | |||||
| Args: | |||||
| in_channels (int|Sequence[int]): Input channels. | |||||
| channels (int): Channels after modules, before conv_seg. | |||||
| num_classes (int): Number of classes. | |||||
| dropout_ratio (float): Ratio of dropout layer. Default: 0.1. | |||||
| conv_cfg (dict|None): Config of conv layers. Default: None. | |||||
| norm_cfg (dict|None): Config of norm layers. Default: None. | |||||
| act_cfg (dict): Config of activation layers. | |||||
| Default: dict(type='ReLU') | |||||
| in_index (int|Sequence[int]): Input feature index. Default: -1 | |||||
| input_transform (str|None): Transformation type of input features. | |||||
| Options: 'resize_concat', 'multiple_select', None. | |||||
| 'resize_concat': Multiple feature maps will be resized to the | |||||
| same size as the first one and then concatenated together. | |||||
| Usually used in FCN head of HRNet. | |||||
| 'multiple_select': Multiple feature maps will be bundled into | |||||
| a list and passed into the decode head. | |||||
| None: Only one selected feature map is allowed. | |||||
| Default: None. | |||||
| loss_decode (dict): Config of decode loss. | |||||
| Default: dict(type='CrossEntropyLoss'). | |||||
| ignore_index (int | None): The label index to be ignored. When using | |||||
| masked BCE loss, ignore_index should be set to None. Default: 255 | |||||
| sampler (dict|None): The config of segmentation map sampler. | |||||
| Default: None. | |||||
| align_corners (bool): align_corners argument of F.interpolate. | |||||
| Default: False. | |||||
| """ | |||||
| def __init__(self, | |||||
| in_channels, | |||||
| channels, | |||||
| *, | |||||
| num_classes, | |||||
| dropout_ratio=0.1, | |||||
| conv_cfg=None, | |||||
| norm_cfg=None, | |||||
| act_cfg=dict(type='ReLU'), | |||||
| in_index=-1, | |||||
| input_transform=None, | |||||
| loss_decode=dict( | |||||
| type='CrossEntropyLoss', | |||||
| use_sigmoid=False, | |||||
| loss_weight=1.0), | |||||
| ignore_index=255, | |||||
| sampler=None, | |||||
| align_corners=False): | |||||
| super(BaseDecodeHead, self).__init__() | |||||
| self._init_inputs(in_channels, in_index, input_transform) | |||||
| self.channels = channels | |||||
| self.num_classes = num_classes | |||||
| self.dropout_ratio = dropout_ratio | |||||
| self.conv_cfg = conv_cfg | |||||
| self.norm_cfg = norm_cfg | |||||
| self.act_cfg = act_cfg | |||||
| self.in_index = in_index | |||||
| # self.loss_decode = build_loss(loss_decode) | |||||
| self.ignore_index = ignore_index | |||||
| self.align_corners = align_corners | |||||
| # if sampler is not None: | |||||
| # self.sampler = build_pixel_sampler(sampler, context=self) | |||||
| # else: | |||||
| # self.sampler = None | |||||
| # self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) | |||||
| # self.conv1 = nn.Conv2d(channels, num_classes, 3, padding=1) | |||||
| if dropout_ratio > 0: | |||||
| self.dropout = nn.Dropout2d(dropout_ratio) | |||||
| else: | |||||
| self.dropout = None | |||||
| self.fp16_enabled = False | |||||
| def extra_repr(self): | |||||
| """Extra repr.""" | |||||
| s = f'input_transform={self.input_transform}, ' \ | |||||
| f'ignore_index={self.ignore_index}, ' \ | |||||
| f'align_corners={self.align_corners}' | |||||
| return s | |||||
| def _init_inputs(self, in_channels, in_index, input_transform): | |||||
| """Check and initialize input transforms. | |||||
| The in_channels, in_index and input_transform must match. | |||||
| Specifically, when input_transform is None, only a single feature map | |||||
| will be selected, so in_channels and in_index must be of type int. | |||||
| When input_transform is not None, in_channels and in_index must be | |||||
| lists or tuples of the same length. | |||||
| Args: | |||||
| in_channels (int|Sequence[int]): Input channels. | |||||
| in_index (int|Sequence[int]): Input feature index. | |||||
| input_transform (str|None): Transformation type of input features. | |||||
| Options: 'resize_concat', 'multiple_select', None. | |||||
| 'resize_concat': Multiple feature maps will be resized to the | |||||
| same size as the first one and then concatenated together. | |||||
| Usually used in FCN head of HRNet. | |||||
| 'multiple_select': Multiple feature maps will be bundled into | |||||
| a list and passed into the decode head. | |||||
| None: Only one selected feature map is allowed. | |||||
| """ | |||||
| if input_transform is not None: | |||||
| assert input_transform in ['resize_concat', 'multiple_select'] | |||||
| self.input_transform = input_transform | |||||
| self.in_index = in_index | |||||
| if input_transform is not None: | |||||
| assert isinstance(in_channels, (list, tuple)) | |||||
| assert isinstance(in_index, (list, tuple)) | |||||
| assert len(in_channels) == len(in_index) | |||||
| if input_transform == 'resize_concat': | |||||
| self.in_channels = sum(in_channels) | |||||
| else: | |||||
| self.in_channels = in_channels | |||||
| else: | |||||
| assert isinstance(in_channels, int) | |||||
| assert isinstance(in_index, int) | |||||
| self.in_channels = in_channels | |||||
| def init_weights(self): | |||||
| """Initialize weights of classification layer.""" | |||||
| # normal_init(self.conv_seg, mean=0, std=0.01) | |||||
| # normal_init(self.conv1, mean=0, std=0.01) | |||||
| def _transform_inputs(self, inputs): | |||||
| """Transform inputs for decoder. | |||||
| Args: | |||||
| inputs (list[Tensor]): List of multi-level img features. | |||||
| Returns: | |||||
| Tensor: The transformed inputs | |||||
| """ | |||||
| if self.input_transform == 'resize_concat': | |||||
| inputs = [inputs[i] for i in self.in_index] | |||||
| upsampled_inputs = [ | |||||
| resize( | |||||
| input=x, | |||||
| size=inputs[0].shape[2:], | |||||
| mode='bilinear', | |||||
| align_corners=self.align_corners) for x in inputs | |||||
| ] | |||||
| inputs = torch.cat(upsampled_inputs, dim=1) | |||||
| elif self.input_transform == 'multiple_select': | |||||
| inputs = [inputs[i] for i in self.in_index] | |||||
| else: | |||||
| inputs = inputs[self.in_index] | |||||
| return inputs | |||||
| def forward(self, inputs): | |||||
| """Placeholder of forward function.""" | |||||
| pass | |||||
| def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg): | |||||
| """Forward function for training. | |||||
| Args: | |||||
| inputs (list[Tensor]): List of multi-level img features. | |||||
| img_metas (list[dict]): List of image info dict where each dict | |||||
| has: 'img_shape', 'scale_factor', 'flip', and may also contain | |||||
| 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |||||
| For details on the values of these keys see | |||||
| `mmseg/datasets/pipelines/formatting.py:Collect`. | |||||
| gt_semantic_seg (Tensor): Semantic segmentation masks | |||||
| used if the architecture supports the semantic segmentation task. | |||||
| train_cfg (dict): The training config. | |||||
| Returns: | |||||
| dict[str, Tensor]: a dictionary of loss components | |||||
| """ | |||||
| seg_logits = self.forward(inputs) | |||||
| losses = self.losses(seg_logits, gt_semantic_seg) | |||||
| return losses | |||||
| def forward_test(self, inputs, img_metas, test_cfg): | |||||
| """Forward function for testing. | |||||
| Args: | |||||
| inputs (list[Tensor]): List of multi-level img features. | |||||
| img_metas (list[dict]): List of image info dict where each dict | |||||
| has: 'img_shape', 'scale_factor', 'flip', and may also contain | |||||
| 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |||||
| For details on the values of these keys see | |||||
| `mmseg/datasets/pipelines/formatting.py:Collect`. | |||||
| test_cfg (dict): The testing config. | |||||
| Returns: | |||||
| Tensor: Output segmentation map. | |||||
| """ | |||||
| return self.forward(inputs) | |||||
| class UPerHead(BaseDecodeHead): | |||||
| """FPN-style decode head from `UPerNet <https://arxiv.org/abs/1807.10221>`_, | |||||
| trimmed here to the lateral/top-down fusion (the PSP branch is commented out). | |||||
| Args: | |||||
| pool_scales (tuple[int]): Kept for API compatibility; unused in this head. | |||||
| """ | |||||
| def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): | |||||
| super(UPerHead, self).__init__( | |||||
| input_transform='multiple_select', **kwargs) | |||||
| # FPN Module | |||||
| self.lateral_convs = nn.ModuleList() | |||||
| self.fpn_convs = nn.ModuleList() | |||||
| for in_channels in self.in_channels: # build a lateral conv and an FPN conv for every input level | |||||
| l_conv = ConvModule( | |||||
| in_channels, | |||||
| self.channels, | |||||
| 1, | |||||
| conv_cfg=self.conv_cfg, | |||||
| norm_cfg=self.norm_cfg, | |||||
| act_cfg=self.act_cfg, | |||||
| inplace=True) | |||||
| fpn_conv = ConvModule( | |||||
| self.channels, | |||||
| self.channels, | |||||
| 3, | |||||
| padding=1, | |||||
| conv_cfg=self.conv_cfg, | |||||
| norm_cfg=self.norm_cfg, | |||||
| act_cfg=self.act_cfg, | |||||
| inplace=True) | |||||
| self.lateral_convs.append(l_conv) | |||||
| self.fpn_convs.append(fpn_conv) | |||||
| def forward(self, inputs): | |||||
| """Forward function.""" | |||||
| inputs = self._transform_inputs(inputs) | |||||
| # build laterals | |||||
| laterals = [ | |||||
| lateral_conv(inputs[i]) | |||||
| for i, lateral_conv in enumerate(self.lateral_convs) | |||||
| ] | |||||
| # laterals.append(self.psp_forward(inputs)) | |||||
| # build top-down path | |||||
| used_backbone_levels = len(laterals) | |||||
| for i in range(used_backbone_levels - 1, 0, -1): | |||||
| prev_shape = laterals[i - 1].shape[2:] | |||||
| laterals[i - 1] += resize( | |||||
| laterals[i], | |||||
| size=prev_shape, | |||||
| mode='bilinear', | |||||
| align_corners=self.align_corners) | |||||
| # build outputs | |||||
| fpn_outs = [ | |||||
| self.fpn_convs[i](laterals[i]) | |||||
| for i in range(used_backbone_levels - 1) | |||||
| ] | |||||
| # append psp feature | |||||
| fpn_outs.append(laterals[-1]) | |||||
| return fpn_outs[0] | |||||
| class PSP(BaseDecodeHead): | |||||
| """Unified Perceptual Parsing for Scene Understanding. | |||||
| This head is the implementation of `UPerNet | |||||
| <https://arxiv.org/abs/1807.10221>`_. | |||||
| Args: | |||||
| pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid | |||||
| Module applied on the last feature. Default: (1, 2, 3, 6). | |||||
| """ | |||||
| def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): | |||||
| super(PSP, self).__init__(input_transform='multiple_select', **kwargs) | |||||
| # PSP Module | |||||
| self.psp_modules = PPM( | |||||
| pool_scales, | |||||
| self.in_channels[-1], | |||||
| self.channels, | |||||
| conv_cfg=self.conv_cfg, | |||||
| norm_cfg=self.norm_cfg, | |||||
| act_cfg=self.act_cfg, | |||||
| align_corners=self.align_corners) | |||||
| self.bottleneck = ConvModule( | |||||
| self.in_channels[-1] + len(pool_scales) * self.channels, | |||||
| self.channels, | |||||
| 3, | |||||
| padding=1, | |||||
| conv_cfg=self.conv_cfg, | |||||
| norm_cfg=self.norm_cfg, | |||||
| act_cfg=self.act_cfg) | |||||
| def psp_forward(self, inputs): | |||||
| """Forward function of PSP module.""" | |||||
| x = inputs[-1] | |||||
| psp_outs = [x] | |||||
| psp_outs.extend(self.psp_modules(x)) | |||||
| psp_outs = torch.cat(psp_outs, dim=1) | |||||
| output = self.bottleneck(psp_outs) | |||||
| return output | |||||
| def forward(self, inputs): | |||||
| """Forward function.""" | |||||
| inputs = self._transform_inputs(inputs) | |||||
| return self.psp_forward(inputs) | |||||
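| # --- Illustrative sketch, not part of the original source: feeding four | |||||
| # backbone feature maps (Swin-L channel sizes assumed) through the two decode | |||||
| # heads defined above. PSP summarizes the coarsest level, while UPerHead | |||||
| # returns the finest level after top-down fusion. | |||||
| if __name__ == '__main__': | |||||
| _feats = [ | |||||
| torch.randn(1, 192, 120, 160), | |||||
| torch.randn(1, 384, 60, 80), | |||||
| torch.randn(1, 768, 30, 40), | |||||
| torch.randn(1, 1536, 15, 20), | |||||
| ] | |||||
| _common = dict( | |||||
| in_channels=[192, 384, 768, 1536], | |||||
| in_index=[0, 1, 2, 3], | |||||
| channels=512, | |||||
| num_classes=1, # required by BaseDecodeHead but unused by these heads | |||||
| norm_cfg=dict(type='BN'), | |||||
| align_corners=False) | |||||
| print(tuple(PSP(pool_scales=(1, 2, 3, 6), **_common)(_feats).shape)) # (1, 512, 15, 20) | |||||
| print(tuple(UPerHead(pool_scales=(1, 2, 3, 6), **_common)(_feats).shape)) # (1, 512, 120, 160) | |||||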
| @@ -0,0 +1,53 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os.path as osp | |||||
| import numpy as np | |||||
| import torch | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base.base_torch_model import TorchModel | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.models.cv.image_depth_estimation.networks.newcrf_depth import \ | |||||
| NewCRFDepth | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| @MODELS.register_module( | |||||
| Tasks.image_depth_estimation, module_name=Models.newcrfs_depth_estimation) | |||||
| class DepthEstimation(TorchModel): | |||||
| def __init__(self, model_dir: str, **kwargs): | |||||
| """str -- model file root.""" | |||||
| super().__init__(model_dir, **kwargs) | |||||
| # build model | |||||
| self.model = NewCRFDepth( | |||||
| version='large07', inv_depth=False, max_depth=10) | |||||
| # load model | |||||
| model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) | |||||
| checkpoint = torch.load(model_path) | |||||
| state_dict = {} | |||||
| # strip the 'module.' prefix left by DataParallel/DistributedDataParallel checkpoints | |||||
| for k in checkpoint['model'].keys(): | |||||
| if k.startswith('module.'): | |||||
| state_dict[k[7:]] = checkpoint['model'][k] | |||||
| else: | |||||
| state_dict[k] = checkpoint['model'][k] | |||||
| self.model.load_state_dict(state_dict) | |||||
| self.model.eval() | |||||
| def forward(self, inputs): | |||||
| return self.model(inputs['imgs']) | |||||
| def postprocess(self, inputs): | |||||
| depth_result = inputs | |||||
| results = {OutputKeys.DEPTHS: depth_result} | |||||
| return results | |||||
| def inference(self, data): | |||||
| results = self.forward(data) | |||||
| return results | |||||
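| # --- Illustrative sketch, not part of the original source: loading this model | |||||
| # through the ModelScope registry and running one forward pass. The model id is | |||||
| # the task default registered in this change and requires the checkpoint to be | |||||
| # available locally or on the hub. | |||||
| if __name__ == '__main__': | |||||
| from modelscope.models import Model | |||||
| model = Model.from_pretrained( | |||||
| 'damo/cv_newcrfs_image-depth-estimation_indoor') | |||||
| imgs = torch.from_numpy( | |||||
| np.random.rand(1, 3, 480, 640).astype(np.float32)) | |||||
| outputs = model.postprocess(model.inference({'imgs': imgs})) | |||||
| print(outputs[OutputKeys.DEPTHS].shape) | |||||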
| @@ -509,8 +509,8 @@ def convert_weights(model: nn.Module): | |||||
| @MODELS.register_module(Tasks.multi_modal_embedding, module_name=Models.clip) | @MODELS.register_module(Tasks.multi_modal_embedding, module_name=Models.clip) | ||||
| class CLIPForMultiModalEmbedding(TorchModel): | class CLIPForMultiModalEmbedding(TorchModel): | ||||
| def __init__(self, model_dir, device_id=-1): | |||||
| super().__init__(model_dir=model_dir, device_id=device_id) | |||||
| def __init__(self, model_dir, *args, **kwargs): | |||||
| super().__init__(model_dir=model_dir, *args, **kwargs) | |||||
| # Initialize the model. | # Initialize the model. | ||||
| vision_model_config_file = '{}/vision_model_config.json'.format( | vision_model_config_file = '{}/vision_model_config.json'.format( | ||||
| @@ -9,7 +9,6 @@ import numpy as np | |||||
| import torch | import torch | ||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Models | from modelscope.metainfo import Models | ||||
| from modelscope.models.base import Tensor, TorchModel | from modelscope.models.base import Tensor, TorchModel | ||||
| from modelscope.models.builder import MODELS | from modelscope.models.builder import MODELS | ||||
| @@ -730,7 +730,7 @@ def make_msa_feat_v2(batch): | |||||
| batch['cluster_profile'], | batch['cluster_profile'], | ||||
| deletion_mean_value, | deletion_mean_value, | ||||
| ] | ] | ||||
| batch['msa_feat'] = torch.concat(msa_feat, dim=-1) | |||||
| batch['msa_feat'] = torch.cat(msa_feat, dim=-1) | |||||
| return batch | return batch | ||||
| @@ -1320,7 +1320,7 @@ def get_contiguous_crop_idx( | |||||
| asym_offset + this_start + csz)) | asym_offset + this_start + csz)) | ||||
| asym_offset += ll | asym_offset += ll | ||||
| return torch.concat(crop_idxs) | |||||
| return torch.cat(crop_idxs) | |||||
| def get_spatial_crop_idx( | def get_spatial_crop_idx( | ||||
| @@ -217,7 +217,7 @@ class MSAAttention(nn.Module): | |||||
| if mask is not None else None) | if mask is not None else None) | ||||
| outputs.append( | outputs.append( | ||||
| self.mha(q=cur_m, k=cur_m, v=cur_m, mask=cur_mask, bias=bias)) | self.mha(q=cur_m, k=cur_m, v=cur_m, mask=cur_mask, bias=bias)) | ||||
| return torch.concat(outputs, dim=-3) | |||||
| return torch.cat(outputs, dim=-3) | |||||
| def _attn_forward(self, m, mask, bias: Optional[torch.Tensor] = None): | def _attn_forward(self, m, mask, bias: Optional[torch.Tensor] = None): | ||||
| m = self.layer_norm_m(m) | m = self.layer_norm_m(m) | ||||
| @@ -19,6 +19,7 @@ class OutputKeys(object): | |||||
| BOXES = 'boxes' | BOXES = 'boxes' | ||||
| KEYPOINTS = 'keypoints' | KEYPOINTS = 'keypoints' | ||||
| MASKS = 'masks' | MASKS = 'masks' | ||||
| DEPTHS = 'depths' | |||||
| TEXT = 'text' | TEXT = 'text' | ||||
| POLYGONS = 'polygons' | POLYGONS = 'polygons' | ||||
| OUTPUT = 'output' | OUTPUT = 'output' | ||||
| @@ -16,7 +16,7 @@ from modelscope.outputs import TASK_OUTPUTS | |||||
| from modelscope.pipeline_inputs import TASK_INPUTS, check_input_type | from modelscope.pipeline_inputs import TASK_INPUTS, check_input_type | ||||
| from modelscope.preprocessors import Preprocessor | from modelscope.preprocessors import Preprocessor | ||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import Frameworks, ModelFile | |||||
| from modelscope.utils.constant import Frameworks, Invoke, ModelFile | |||||
| from modelscope.utils.device import (create_device, device_placement, | from modelscope.utils.device import (create_device, device_placement, | ||||
| verify_device) | verify_device) | ||||
| from modelscope.utils.hub import read_config, snapshot_download | from modelscope.utils.hub import read_config, snapshot_download | ||||
| @@ -47,8 +47,10 @@ class Pipeline(ABC): | |||||
| logger.info(f'initiate model from location {model}.') | logger.info(f'initiate model from location {model}.') | ||||
| # expecting model has been prefetched to local cache beforehand | # expecting model has been prefetched to local cache beforehand | ||||
| return Model.from_pretrained( | return Model.from_pretrained( | ||||
| model, model_prefetched=True, | |||||
| device=self.device_name) if is_model(model) else model | |||||
| model, | |||||
| device=self.device_name, | |||||
| model_prefetched=True, | |||||
| invoked_by=Invoke.PIPELINE) if is_model(model) else model | |||||
| else: | else: | ||||
| return model | return model | ||||
| @@ -231,7 +233,7 @@ class Pipeline(ABC): | |||||
| batch_data[k] = value_list | batch_data[k] = value_list | ||||
| for k in batch_data.keys(): | for k in batch_data.keys(): | ||||
| if isinstance(batch_data[k][0], torch.Tensor): | if isinstance(batch_data[k][0], torch.Tensor): | ||||
| batch_data[k] = torch.concat(batch_data[k]) | |||||
| batch_data[k] = torch.cat(batch_data[k]) | |||||
| return batch_data | return batch_data | ||||
| def _process_batch(self, input: List[Input], batch_size, | def _process_batch(self, input: List[Input], batch_size, | ||||
| @@ -383,15 +385,12 @@ class DistributedPipeline(Pipeline): | |||||
| preprocessor: Union[Preprocessor, List[Preprocessor]] = None, | preprocessor: Union[Preprocessor, List[Preprocessor]] = None, | ||||
| auto_collate=True, | auto_collate=True, | ||||
| **kwargs): | **kwargs): | ||||
| self.preprocessor = preprocessor | |||||
| super().__init__(model=model, preprocessor=preprocessor, kwargs=kwargs) | |||||
| self._model_prepare = False | self._model_prepare = False | ||||
| self._model_prepare_lock = Lock() | self._model_prepare_lock = Lock() | ||||
| self._auto_collate = auto_collate | self._auto_collate = auto_collate | ||||
| if os.path.exists(model): | |||||
| self.model_dir = model | |||||
| else: | |||||
| self.model_dir = snapshot_download(model) | |||||
| self.model_dir = self.model.model_dir | |||||
| self.cfg = read_config(self.model_dir) | self.cfg = read_config(self.model_dir) | ||||
| self.world_size = self.cfg.model.world_size | self.world_size = self.cfg.model.world_size | ||||
| self.model_pool = None | self.model_pool = None | ||||
| @@ -7,7 +7,7 @@ from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
| from modelscope.models.base import Model | from modelscope.models.base import Model | ||||
| from modelscope.utils.config import ConfigDict, check_config | from modelscope.utils.config import ConfigDict, check_config | ||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Tasks | |||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, Tasks | |||||
| from modelscope.utils.hub import read_config | from modelscope.utils.hub import read_config | ||||
| from modelscope.utils.registry import Registry, build_from_cfg | from modelscope.utils.registry import Registry, build_from_cfg | ||||
| from .base import Pipeline | from .base import Pipeline | ||||
| @@ -147,6 +147,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| Tasks.image_segmentation: | Tasks.image_segmentation: | ||||
| (Pipelines.image_instance_segmentation, | (Pipelines.image_instance_segmentation, | ||||
| 'damo/cv_swin-b_image-instance-segmentation_coco'), | 'damo/cv_swin-b_image-instance-segmentation_coco'), | ||||
| Tasks.image_depth_estimation: | |||||
| (Pipelines.image_depth_estimation, | |||||
| 'damo/cv_newcrfs_image-depth-estimation_indoor'), | |||||
| Tasks.image_style_transfer: (Pipelines.image_style_transfer, | Tasks.image_style_transfer: (Pipelines.image_style_transfer, | ||||
| 'damo/cv_aams_style-transfer_damo'), | 'damo/cv_aams_style-transfer_damo'), | ||||
| Tasks.face_image_generation: (Pipelines.face_image_generation, | Tasks.face_image_generation: (Pipelines.face_image_generation, | ||||
| @@ -209,6 +212,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| Tasks.referring_video_object_segmentation: | Tasks.referring_video_object_segmentation: | ||||
| (Pipelines.referring_video_object_segmentation, | (Pipelines.referring_video_object_segmentation, | ||||
| 'damo/cv_swin-t_referring_video-object-segmentation'), | 'damo/cv_swin-t_referring_video-object-segmentation'), | ||||
| Tasks.video_summarization: (Pipelines.video_summarization, | |||||
| 'damo/cv_googlenet_pgl-video-summarization'), | |||||
| } | } | ||||
| @@ -220,14 +225,19 @@ def normalize_model_input(model, model_revision): | |||||
| # skip revision download if model is a local directory | # skip revision download if model is a local directory | ||||
| if not os.path.exists(model): | if not os.path.exists(model): | ||||
| # note that if there is already a local copy, snapshot_download will check and skip downloading | # note that if there is already a local copy, snapshot_download will check and skip downloading | ||||
| model = snapshot_download(model, revision=model_revision) | |||||
| model = snapshot_download( | |||||
| model, | |||||
| revision=model_revision, | |||||
| user_agent={Invoke.KEY: Invoke.PIPELINE}) | |||||
| elif isinstance(model, list) and isinstance(model[0], str): | elif isinstance(model, list) and isinstance(model[0], str): | ||||
| for idx in range(len(model)): | for idx in range(len(model)): | ||||
| if is_official_hub_path( | if is_official_hub_path( | ||||
| model[idx], | model[idx], | ||||
| model_revision) and not os.path.exists(model[idx]): | model_revision) and not os.path.exists(model[idx]): | ||||
| model[idx] = snapshot_download( | model[idx] = snapshot_download( | ||||
| model[idx], revision=model_revision) | |||||
| model[idx], | |||||
| revision=model_revision, | |||||
| user_agent={Invoke.KEY: Invoke.PIPELINE}) | |||||
| return model | return model | ||||
| @@ -8,14 +8,13 @@ import torch | |||||
| from PIL import Image | from PIL import Image | ||||
| from torchvision import transforms | from torchvision import transforms | ||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
| from modelscope.models.cv.animal_recognition import Bottleneck, ResNet | from modelscope.models.cv.animal_recognition import Bottleneck, ResNet | ||||
| from modelscope.outputs import OutputKeys | from modelscope.outputs import OutputKeys | ||||
| from modelscope.pipelines.base import Input, Pipeline | from modelscope.pipelines.base import Input, Pipeline | ||||
| from modelscope.pipelines.builder import PIPELINES | from modelscope.pipelines.builder import PIPELINES | ||||
| from modelscope.preprocessors import LoadImage | from modelscope.preprocessors import LoadImage | ||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.constant import Devices, ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -67,15 +66,10 @@ class AnimalRecognitionPipeline(Pipeline): | |||||
| filter_param(src_params, own_state) | filter_param(src_params, own_state) | ||||
| model.load_state_dict(own_state) | model.load_state_dict(own_state) | ||||
| self.model = resnest101(num_classes=8288) | |||||
| local_model_dir = model | |||||
| if osp.exists(model): | |||||
| local_model_dir = model | |||||
| else: | |||||
| local_model_dir = snapshot_download(model) | |||||
| self.local_path = local_model_dir | |||||
| self.local_path = self.model | |||||
| src_params = torch.load( | src_params = torch.load( | ||||
| osp.join(local_model_dir, 'pytorch_model.pt'), 'cpu') | |||||
| osp.join(self.local_path, ModelFile.TORCH_MODEL_FILE), Devices.cpu) | |||||
| self.model = resnest101(num_classes=8288) | |||||
| load_pretrained(self.model, src_params) | load_pretrained(self.model, src_params) | ||||
| logger.info('load model done') | logger.info('load model done') | ||||
| @@ -120,8 +120,7 @@ class Body3DKeypointsPipeline(Pipeline): | |||||
| """ | """ | ||||
| super().__init__(model=model, **kwargs) | super().__init__(model=model, **kwargs) | ||||
| self.keypoint_model_3d = model if isinstance( | |||||
| model, BodyKeypointsDetection3D) else Model.from_pretrained(model) | |||||
| self.keypoint_model_3d = self.model | |||||
| self.keypoint_model_3d.eval() | self.keypoint_model_3d.eval() | ||||
| # init human body 2D keypoints detection pipeline | # init human body 2D keypoints detection pipeline | ||||
| @@ -11,7 +11,7 @@ from PIL import ImageFile | |||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.pipelines.util import is_official_hub_path | from modelscope.pipelines.util import is_official_hub_path | ||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile | |||||
| from modelscope.utils.device import create_device | from modelscope.utils.device import create_device | ||||
| @@ -37,7 +37,9 @@ class EasyCVPipeline(object): | |||||
| assert is_official_hub_path( | assert is_official_hub_path( | ||||
| model), 'Only support local model path and official hub path!' | model), 'Only support local model path and official hub path!' | ||||
| model_dir = snapshot_download( | model_dir = snapshot_download( | ||||
| model_id=model, revision=DEFAULT_MODEL_REVISION) | |||||
| model_id=model, | |||||
| revision=DEFAULT_MODEL_REVISION, | |||||
| user_agent={Invoke.KEY: Invoke.PIPELINE}) | |||||
| assert osp.isdir(model_dir) | assert osp.isdir(model_dir) | ||||
| model_files = glob.glob( | model_files = glob.glob( | ||||
| @@ -48,6 +50,7 @@ class EasyCVPipeline(object): | |||||
| model_path = model_files[0] | model_path = model_files[0] | ||||
| self.model_path = model_path | self.model_path = model_path | ||||
| self.model_dir = model_dir | |||||
| # get configuration file from source model dir | # get configuration file from source model dir | ||||
| self.config_file = os.path.join(model_dir, ModelFile.CONFIGURATION) | self.config_file = os.path.join(model_dir, ModelFile.CONFIGURATION) | ||||
| @@ -24,7 +24,6 @@ class HumanWholebodyKeypointsPipeline(EasyCVPipeline): | |||||
| model (str): model id on modelscope hub or local model path. | model (str): model id on modelscope hub or local model path. | ||||
| model_file_pattern (str): model file pattern. | model_file_pattern (str): model file pattern. | ||||
| """ | """ | ||||
| self.model_dir = model | |||||
| super(HumanWholebodyKeypointsPipeline, self).__init__( | super(HumanWholebodyKeypointsPipeline, self).__init__( | ||||
| model=model, | model=model, | ||||
| model_file_pattern=model_file_pattern, | model_file_pattern=model_file_pattern, | ||||
| @@ -8,7 +8,6 @@ import torch | |||||
| from PIL import Image | from PIL import Image | ||||
| from torchvision import transforms | from torchvision import transforms | ||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
| from modelscope.models.cv.animal_recognition import resnet | from modelscope.models.cv.animal_recognition import resnet | ||||
| from modelscope.outputs import OutputKeys | from modelscope.outputs import OutputKeys | ||||
| @@ -67,16 +66,12 @@ class GeneralRecognitionPipeline(Pipeline): | |||||
| filter_param(src_params, own_state) | filter_param(src_params, own_state) | ||||
| model.load_state_dict(own_state) | model.load_state_dict(own_state) | ||||
| self.model = resnest101(num_classes=54092) | |||||
| local_model_dir = model | |||||
| device = 'cpu' | device = 'cpu' | ||||
| if osp.exists(model): | |||||
| local_model_dir = model | |||||
| else: | |||||
| local_model_dir = snapshot_download(model) | |||||
| self.local_path = local_model_dir | |||||
| self.local_path = self.model | |||||
| src_params = torch.load( | src_params = torch.load( | ||||
| osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE), device) | |||||
| osp.join(self.local_path, ModelFile.TORCH_MODEL_FILE), device) | |||||
| self.model = resnest101(num_classes=54092) | |||||
| load_pretrained(self.model, src_params) | load_pretrained(self.model, src_params) | ||||
| logger.info('load model done') | logger.info('load model done') | ||||
| @@ -21,7 +21,6 @@ class Hand2DKeypointsPipeline(EasyCVPipeline): | |||||
| model (str): model id on modelscope hub or local model path. | model (str): model id on modelscope hub or local model path. | ||||
| model_file_pattern (str): model file pattern. | model_file_pattern (str): model file pattern. | ||||
| """ | """ | ||||
| self.model_dir = model | |||||
| super(Hand2DKeypointsPipeline, self).__init__( | super(Hand2DKeypointsPipeline, self).__init__( | ||||
| model=model, | model=model, | ||||
| model_file_pattern=model_file_pattern, | model_file_pattern=model_file_pattern, | ||||
| @@ -1,5 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from typing import Any, Dict, Union | |||||
| from typing import Any, Dict, Optional, Union | |||||
| import cv2 | import cv2 | ||||
| import numpy as np | import numpy as np | ||||
| @@ -25,22 +25,15 @@ class ImageClassificationPipeline(Pipeline): | |||||
| def __init__(self, | def __init__(self, | ||||
| model: Union[Model, str], | model: Union[Model, str], | ||||
| preprocessor: [Preprocessor] = None, | |||||
| preprocessor: Optional[Preprocessor] = None, | |||||
| **kwargs): | **kwargs): | ||||
| super().__init__(model=model) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| assert isinstance(model, str) or isinstance(model, Model), \ | assert isinstance(model, str) or isinstance(model, Model), \ | ||||
| 'model must be a single str or OfaForAllTasks' | 'model must be a single str or OfaForAllTasks' | ||||
| if isinstance(model, str): | |||||
| pipe_model = Model.from_pretrained(model) | |||||
| elif isinstance(model, Model): | |||||
| pipe_model = model | |||||
| else: | |||||
| raise NotImplementedError | |||||
| pipe_model.model.eval() | |||||
| pipe_model.to(get_device()) | |||||
| if preprocessor is None and isinstance(pipe_model, OfaForAllTasks): | |||||
| preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) | |||||
| super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) | |||||
| self.model.eval() | |||||
| self.model.to(get_device()) | |||||
| if preprocessor is None and isinstance(self.model, OfaForAllTasks): | |||||
| self.preprocessor = OfaPreprocessor(model_dir=self.model.model_dir) | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | ||||
| return inputs | return inputs | ||||
| @@ -32,10 +32,8 @@ class ImageColorEnhancePipeline(Pipeline): | |||||
| Args: | Args: | ||||
| model: model id on modelscope hub. | model: model id on modelscope hub. | ||||
| """ | """ | ||||
| model = model if isinstance( | |||||
| model, ImageColorEnhance) else Model.from_pretrained(model) | |||||
| model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | super().__init__(model=model, preprocessor=preprocessor, **kwargs) | ||||
| self.model.eval() | |||||
| if torch.cuda.is_available(): | if torch.cuda.is_available(): | ||||
| self._device = torch.device('cuda') | self._device = torch.device('cuda') | ||||
| @@ -32,17 +32,14 @@ class ImageDenoisePipeline(Pipeline): | |||||
| Args: | Args: | ||||
| model: model id on modelscope hub. | model: model id on modelscope hub. | ||||
| """ | """ | ||||
| model = model if isinstance( | |||||
| model, NAFNetForImageDenoise) else Model.from_pretrained(model) | |||||
| model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | super().__init__(model=model, preprocessor=preprocessor, **kwargs) | ||||
| self.config = model.config | |||||
| self.model.eval() | |||||
| self.config = self.model.config | |||||
| if torch.cuda.is_available(): | if torch.cuda.is_available(): | ||||
| self._device = torch.device('cuda') | self._device = torch.device('cuda') | ||||
| else: | else: | ||||
| self._device = torch.device('cpu') | self._device = torch.device('cpu') | ||||
| self.model = model | |||||
| logger.info('load image denoise model done') | logger.info('load image denoise model done') | ||||
| def preprocess(self, input: Input) -> Dict[str, Any]: | def preprocess(self, input: Input) -> Dict[str, Any]: | ||||
| @@ -0,0 +1,52 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import Any, Dict, Union | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import PIL | |||||
| import torch | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Input, Model, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import LoadImage | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| @PIPELINES.register_module( | |||||
| Tasks.image_depth_estimation, module_name=Pipelines.image_depth_estimation) | |||||
| class ImageDepthEstimationPipeline(Pipeline): | |||||
| def __init__(self, model: str, **kwargs): | |||||
| """ | |||||
| use `model` to create an image depth estimation pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model, **kwargs) | |||||
| logger.info('depth estimation model, pipeline init') | |||||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||||
| img = LoadImage.convert_to_ndarray(input).astype(np.float32) | |||||
| # resize to the fixed network input resolution and scale pixel values to [0, 1] | |||||
| H, W = 480, 640 | |||||
| img = cv2.resize(img, (W, H)) | |||||
| img = img.transpose(2, 0, 1) / 255.0 # HWC -> CHW | |||||
| imgs = img[None, ...] # add batch dimension | |||||
| data = {'imgs': imgs} | |||||
| return data | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
| results = self.model.inference(input) | |||||
| return results | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| results = self.model.postprocess(inputs) | |||||
| outputs = {OutputKeys.DEPTHS: results[OutputKeys.DEPTHS]} | |||||
| return outputs | |||||
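| # --- Illustrative sketch, not part of the original source: running the pipeline | |||||
| # end to end. The model id is the task default registered in this change; the | |||||
| # input image path is a placeholder. | |||||
| if __name__ == '__main__': | |||||
| from modelscope.pipelines import pipeline | |||||
| estimator = pipeline( | |||||
| Tasks.image_depth_estimation, | |||||
| model='damo/cv_newcrfs_image-depth-estimation_indoor') | |||||
| result = estimator('path/to/indoor_image.jpg') | |||||
| print(result[OutputKeys.DEPTHS]) | |||||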
| @@ -44,7 +44,7 @@ class LanguageGuidedVideoSummarizationPipeline(Pipeline): | |||||
| """ | """ | ||||
| super().__init__(model=model, auto_collate=False, **kwargs) | super().__init__(model=model, auto_collate=False, **kwargs) | ||||
| logger.info(f'loading model from {model}') | logger.info(f'loading model from {model}') | ||||
| self.model_dir = model | |||||
| self.model_dir = self.model.model_dir | |||||
| self.tmp_dir = kwargs.get('tmp_dir', None) | self.tmp_dir = kwargs.get('tmp_dir', None) | ||||
| if self.tmp_dir is None: | if self.tmp_dir is None: | ||||
| @@ -9,7 +9,6 @@ import PIL | |||||
| import torch | import torch | ||||
| from PIL import Image | from PIL import Image | ||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
| from modelscope.models.cv.virual_tryon import SDAFNet_Tryon | from modelscope.models.cv.virual_tryon import SDAFNet_Tryon | ||||
| from modelscope.outputs import OutputKeys | from modelscope.outputs import OutputKeys | ||||
| @@ -52,17 +51,12 @@ class VirtualTryonPipeline(Pipeline): | |||||
| filter_param(src_params, own_state) | filter_param(src_params, own_state) | ||||
| model.load_state_dict(own_state) | model.load_state_dict(own_state) | ||||
| self.model = SDAFNet_Tryon(ref_in_channel=6).to(self.device) | |||||
| local_model_dir = model | |||||
| if osp.exists(model): | |||||
| local_model_dir = model | |||||
| else: | |||||
| local_model_dir = snapshot_download(model) | |||||
| self.local_path = local_model_dir | |||||
| self.local_path = self.model | |||||
| src_params = torch.load( | src_params = torch.load( | ||||
| osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE), 'cpu') | |||||
| osp.join(self.local_path, ModelFile.TORCH_MODEL_FILE), 'cpu') | |||||
| self.model = SDAFNet_Tryon(ref_in_channel=6).to(self.device) | |||||
| load_pretrained(self.model, src_params) | load_pretrained(self.model, src_params) | ||||
| self.model = self.model.eval() | |||||
| self.model.eval() | |||||
| self.size = 192 | self.size = 192 | ||||
| from torchvision import transforms | from torchvision import transforms | ||||
| self.test_transforms = transforms.Compose([ | self.test_transforms = transforms.Compose([ | ||||
| @@ -29,22 +29,13 @@ class ImageCaptioningPipeline(Pipeline): | |||||
| Args: | Args: | ||||
| model: model id on modelscope hub. | model: model id on modelscope hub. | ||||
| """ | """ | ||||
| super().__init__(model=model) | |||||
| assert isinstance(model, str) or isinstance(model, Model), \ | |||||
| 'model must be a single str or OfaForAllTasks' | |||||
| if isinstance(model, str): | |||||
| pipe_model = Model.from_pretrained(model) | |||||
| elif isinstance(model, Model): | |||||
| pipe_model = model | |||||
| else: | |||||
| raise NotImplementedError | |||||
| pipe_model.model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.model.eval() | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| if isinstance(pipe_model, OfaForAllTasks): | |||||
| preprocessor = OfaPreprocessor(pipe_model.model_dir) | |||||
| elif isinstance(pipe_model, MPlugForAllTasks): | |||||
| preprocessor = MPlugPreprocessor(pipe_model.model_dir) | |||||
| super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) | |||||
| if isinstance(self.model, OfaForAllTasks): | |||||
| self.preprocessor = OfaPreprocessor(self.model.model_dir) | |||||
| elif isinstance(self.model, MPlugForAllTasks): | |||||
| self.preprocessor = MPlugPreprocessor(self.model.model_dir) | |||||
| def _batch(self, data): | def _batch(self, data): | ||||
| if isinstance(self.model, OfaForAllTasks): | if isinstance(self.model, OfaForAllTasks): | ||||
| @@ -55,17 +46,17 @@ class ImageCaptioningPipeline(Pipeline): | |||||
| batch_data['samples'] = [d['samples'][0] for d in data] | batch_data['samples'] = [d['samples'][0] for d in data] | ||||
| batch_data['net_input'] = {} | batch_data['net_input'] = {} | ||||
| for k in data[0]['net_input'].keys(): | for k in data[0]['net_input'].keys(): | ||||
| batch_data['net_input'][k] = torch.concat( | |||||
| batch_data['net_input'][k] = torch.cat( | |||||
| [d['net_input'][k] for d in data]) | [d['net_input'][k] for d in data]) | ||||
| return batch_data | return batch_data | ||||
| elif isinstance(self.model, MPlugForAllTasks): | elif isinstance(self.model, MPlugForAllTasks): | ||||
| from transformers.tokenization_utils_base import BatchEncoding | from transformers.tokenization_utils_base import BatchEncoding | ||||
| batch_data = dict(train=data[0]['train']) | batch_data = dict(train=data[0]['train']) | ||||
| batch_data['image'] = torch.concat([d['image'] for d in data]) | |||||
| batch_data['image'] = torch.cat([d['image'] for d in data]) | |||||
| question = {} | question = {} | ||||
| for k in data[0]['question'].keys(): | for k in data[0]['question'].keys(): | ||||
| question[k] = torch.concat([d['question'][k] for d in data]) | |||||
| question[k] = torch.cat([d['question'][k] for d in data]) | |||||
| batch_data['question'] = BatchEncoding(question) | batch_data['question'] = BatchEncoding(question) | ||||
| return batch_data | return batch_data | ||||
| else: | else: | ||||
| @@ -28,19 +28,10 @@ class ImageTextRetrievalPipeline(Pipeline): | |||||
| Args: | Args: | ||||
| model: model id on modelscope hub. | model: model id on modelscope hub. | ||||
| """ | """ | ||||
| super().__init__(model=model) | |||||
| assert isinstance(model, str) or isinstance(model, Model), \ | |||||
| f'model must be a single str or Model, but got {type(model)}' | |||||
| if isinstance(model, str): | |||||
| pipe_model = Model.from_pretrained(model) | |||||
| elif isinstance(model, Model): | |||||
| pipe_model = model | |||||
| else: | |||||
| raise NotImplementedError | |||||
| pipe_model.model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.model.eval() | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = MPlugPreprocessor(pipe_model.model_dir) | |||||
| super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) | |||||
| self.preprocessor = MPlugPreprocessor(self.model.model_dir) | |||||
| def forward(self, inputs: Dict[str, Any], | def forward(self, inputs: Dict[str, Any], | ||||
| **forward_params) -> Dict[str, Any]: | **forward_params) -> Dict[str, Any]: | ||||
| @@ -28,21 +28,14 @@ class MultiModalEmbeddingPipeline(Pipeline): | |||||
| Args: | Args: | ||||
| model: model id on modelscope hub. | model: model id on modelscope hub. | ||||
| """ | """ | ||||
| if isinstance(model, str): | |||||
| pipe_model = Model.from_pretrained(model) | |||||
| elif isinstance(model, Model): | |||||
| pipe_model = model | |||||
| else: | |||||
| raise NotImplementedError('model must be a single str') | |||||
| pipe_model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.model.eval() | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| if isinstance(pipe_model, CLIPForMultiModalEmbedding): | |||||
| preprocessor = CLIPPreprocessor(pipe_model.model_dir) | |||||
| if isinstance(self.model, CLIPForMultiModalEmbedding): | |||||
| self.preprocessor = CLIPPreprocessor(self.model.model_dir) | |||||
| else: | else: | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | ||||
| return self.model(self.preprocess(input)) | return self.model(self.preprocess(input)) | ||||
| @@ -28,20 +28,11 @@ class OcrRecognitionPipeline(Pipeline): | |||||
| Args: | Args: | ||||
| model: model id on modelscope hub. | model: model id on modelscope hub. | ||||
| """ | """ | ||||
| super().__init__(model=model) | |||||
| assert isinstance(model, str) or isinstance(model, Model), \ | |||||
| 'model must be a single str or OfaForAllTasks' | |||||
| if isinstance(model, str): | |||||
| pipe_model = Model.from_pretrained(model) | |||||
| elif isinstance(model, Model): | |||||
| pipe_model = model | |||||
| else: | |||||
| raise NotImplementedError | |||||
| pipe_model.model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.model.eval() | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| if isinstance(pipe_model, OfaForAllTasks): | |||||
| preprocessor = OfaPreprocessor(pipe_model.model_dir) | |||||
| super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) | |||||
| if isinstance(self.model, OfaForAllTasks): | |||||
| self.preprocessor = OfaPreprocessor(self.model.model_dir) | |||||
| def forward(self, inputs: Dict[str, Any], | def forward(self, inputs: Dict[str, Any], | ||||
| **forward_params) -> Dict[str, Any]: | **forward_params) -> Dict[str, Any]: | ||||
| @@ -31,18 +31,10 @@ class TextToImageSynthesisPipeline(Pipeline): | |||||
| Args: | Args: | ||||
| model: model id on modelscope hub. | model: model id on modelscope hub. | ||||
| """ | """ | ||||
| device_id = 0 if torch.cuda.is_available() else -1 | |||||
| if isinstance(model, str): | |||||
| pipe_model = Model.from_pretrained(model, device_id=device_id) | |||||
| elif isinstance(model, Model): | |||||
| pipe_model = model | |||||
| else: | |||||
| raise NotImplementedError( | |||||
| f'expecting a Model instance or str, but get {type(model)}.') | |||||
| if preprocessor is None and isinstance(pipe_model, | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| if preprocessor is None and isinstance(self.model, | |||||
| OfaForTextToImageSynthesis): | OfaForTextToImageSynthesis): | ||||
| preprocessor = OfaPreprocessor(pipe_model.model_dir) | |||||
| super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) | |||||
| self.preprocessor = OfaPreprocessor(self.model.model_dir) | |||||
| def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]: | def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]: | ||||
| if self.preprocessor is not None: | if self.preprocessor is not None: | ||||
| @@ -1,5 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from typing import Any, Dict, Union | |||||
| from typing import Any, Dict, Optional, Union | |||||
| from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
| from modelscope.models.multi_modal import OfaForAllTasks | from modelscope.models.multi_modal import OfaForAllTasks | ||||
| @@ -18,26 +18,17 @@ class VisualEntailmentPipeline(Pipeline): | |||||
| def __init__(self, | def __init__(self, | ||||
| model: Union[Model, str], | model: Union[Model, str], | ||||
| preprocessor: [Preprocessor] = None, | |||||
| preprocessor: Optional[Preprocessor] = None, | |||||
| **kwargs): | **kwargs): | ||||
| """ | """ | ||||
| use `model` and `preprocessor` to create a visual entailment pipeline for prediction | use `model` and `preprocessor` to create a visual entailment pipeline for prediction | ||||
| Args: | Args: | ||||
| model: model id on modelscope hub. | model: model id on modelscope hub. | ||||
| """ | """ | ||||
| super().__init__(model=model) | |||||
| assert isinstance(model, str) or isinstance(model, Model), \ | |||||
| 'model must be a single str or OfaForAllTasks' | |||||
| if isinstance(model, str): | |||||
| pipe_model = Model.from_pretrained(model) | |||||
| elif isinstance(model, Model): | |||||
| pipe_model = model | |||||
| else: | |||||
| raise NotImplementedError | |||||
| pipe_model.model.eval() | |||||
| if preprocessor is None and isinstance(pipe_model, OfaForAllTasks): | |||||
| preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) | |||||
| super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.model.eval() | |||||
| if preprocessor is None and isinstance(self.model, OfaForAllTasks): | |||||
| self.preprocessor = OfaPreprocessor(model_dir=self.model.model_dir) | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | ||||
| return inputs | return inputs | ||||
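Aside on the signature change above: `preprocessor: [Preprocessor] = None` annotates the parameter with a one-element list literal, which type checkers reject, while `Optional[Preprocessor]` means "a Preprocessor or None". The same correction recurs below for VisualGrounding, MGLM and Summarization, and the analogous `Union[...]` fix appears in FaqQuestionAnsweringPipeline. A small illustration (the classes are placeholders):

    from typing import Optional, Union

    class Preprocessor:   # placeholder for the real class
        pass

    # Wrong: the annotation is a list literal containing a class, not a type.
    def init_bad(preprocessor: [Preprocessor] = None): ...

    # Right: explicitly "Preprocessor or None".
    def init_ok(preprocessor: Optional[Preprocessor] = None): ...

    # Same idea for multi-type inputs, as in FaqQuestionAnsweringPipeline.forward:
    def forward(inputs: Union[list, dict]) -> dict:
        return {"n": len(inputs)}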
| @@ -1,5 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from typing import Any, Dict, Union | |||||
| from typing import Any, Dict, Optional, Union | |||||
| from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
| from modelscope.models.multi_modal import OfaForAllTasks | from modelscope.models.multi_modal import OfaForAllTasks | ||||
| @@ -18,26 +18,17 @@ class VisualGroundingPipeline(Pipeline): | |||||
| def __init__(self, | def __init__(self, | ||||
| model: Union[Model, str], | model: Union[Model, str], | ||||
| preprocessor: [Preprocessor] = None, | |||||
| preprocessor: Optional[Preprocessor] = None, | |||||
| **kwargs): | **kwargs): | ||||
| """ | """ | ||||
| use `model` and `preprocessor` to create a visual grounding pipeline for prediction | use `model` and `preprocessor` to create a visual grounding pipeline for prediction | ||||
| Args: | Args: | ||||
| model: model id on modelscope hub. | model: model id on modelscope hub. | ||||
| """ | """ | ||||
| super().__init__(model=model) | |||||
| assert isinstance(model, str) or isinstance(model, Model), \ | |||||
| 'model must be a single str or OfaForAllTasks' | |||||
| if isinstance(model, str): | |||||
| pipe_model = Model.from_pretrained(model) | |||||
| elif isinstance(model, Model): | |||||
| pipe_model = model | |||||
| else: | |||||
| raise NotImplementedError | |||||
| pipe_model.model.eval() | |||||
| if preprocessor is None and isinstance(pipe_model, OfaForAllTasks): | |||||
| preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) | |||||
| super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.model.model.eval() | |||||
| if preprocessor is None and isinstance(self.model, OfaForAllTasks): | |||||
| self.preprocessor = OfaPreprocessor(model_dir=self.model.model_dir) | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | ||||
| return inputs | return inputs | ||||
| @@ -31,15 +31,13 @@ class VisualQuestionAnsweringPipeline(Pipeline): | |||||
| model (MPlugForVisualQuestionAnswering): a model instance | model (MPlugForVisualQuestionAnswering): a model instance | ||||
| preprocessor (MPlugVisualQuestionAnsweringPreprocessor): a preprocessor instance | preprocessor (MPlugVisualQuestionAnsweringPreprocessor): a preprocessor instance | ||||
| """ | """ | ||||
| model = model if isinstance(model, | |||||
| Model) else Model.from_pretrained(model) | |||||
| if preprocessor is None: | |||||
| if isinstance(model, OfaForAllTasks): | |||||
| preprocessor = OfaPreprocessor(model.model_dir) | |||||
| elif isinstance(model, MPlugForAllTasks): | |||||
| preprocessor = MPlugPreprocessor(model.model_dir) | |||||
| model.model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | super().__init__(model=model, preprocessor=preprocessor, **kwargs) | ||||
| if preprocessor is None: | |||||
| if isinstance(self.model, OfaForAllTasks): | |||||
| self.preprocessor = OfaPreprocessor(self.model.model_dir) | |||||
| elif isinstance(self.model, MPlugForAllTasks): | |||||
| self.preprocessor = MPlugPreprocessor(self.model.model_dir) | |||||
| self.model.eval() | |||||
| def forward(self, inputs: Dict[str, Any], | def forward(self, inputs: Dict[str, Any], | ||||
| **forward_params) -> Dict[str, Any]: | **forward_params) -> Dict[str, Any]: | ||||
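VisualQuestionAnsweringPipeline now picks its default preprocessor by inspecting the resolved self.model class (OFA vs MPlug). If more backends accumulate, the isinstance chain can be written as a table from model class to preprocessor factory; a sketch of that idea only, with stand-in classes (the mapping itself is not part of this patch):

    # Table-driven choice of the default preprocessor, equivalent to the
    # isinstance chain above. All classes here are stand-ins.
    class OfaModel:
        model_dir = "/cache/ofa"

    class MPlugModel:
        model_dir = "/cache/mplug"

    class OfaPreprocessor:
        def __init__(self, model_dir):
            self.model_dir = model_dir

    class MPlugPreprocessor:
        def __init__(self, model_dir):
            self.model_dir = model_dir

    DEFAULT_PREPROCESSORS = [
        (OfaModel, OfaPreprocessor),
        (MPlugModel, MPlugPreprocessor),
    ]

    def default_preprocessor(model):
        for model_cls, preprocessor_cls in DEFAULT_PREPROCESSORS:
            if isinstance(model, model_cls):
                return preprocessor_cls(model.model_dir)
        raise NotImplementedError(f"no default preprocessor for {type(model).__name__}")

    print(type(default_preprocessor(MPlugModel())).__name__)   # MPlugPreprocessor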
| @@ -32,12 +32,10 @@ class ConversationalTextToSqlPipeline(Pipeline): | |||||
| preprocessor (ConversationalTextToSqlPreprocessor): | preprocessor (ConversationalTextToSqlPreprocessor): | ||||
| a preprocessor instance | a preprocessor instance | ||||
| """ | """ | ||||
| model = model if isinstance( | |||||
| model, StarForTextToSql) else Model.from_pretrained(model) | |||||
| if preprocessor is None: | |||||
| preprocessor = ConversationalTextToSqlPreprocessor(model.model_dir) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | super().__init__(model=model, preprocessor=preprocessor, **kwargs) | ||||
| if preprocessor is None: | |||||
| self.preprocessor = ConversationalTextToSqlPreprocessor( | |||||
| self.model.model_dir) | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | ||||
| """process the prediction results | """process the prediction results | ||||
| @@ -30,13 +30,11 @@ class DialogIntentPredictionPipeline(Pipeline): | |||||
| or a SpaceForDialogIntent instance. | or a SpaceForDialogIntent instance. | ||||
| preprocessor (DialogIntentPredictionPreprocessor): An optional preprocessor instance. | preprocessor (DialogIntentPredictionPreprocessor): An optional preprocessor instance. | ||||
| """ | """ | ||||
| model = model if isinstance( | |||||
| model, SpaceForDialogIntent) else Model.from_pretrained(model) | |||||
| if preprocessor is None: | |||||
| preprocessor = DialogIntentPredictionPreprocessor(model.model_dir) | |||||
| self.model = model | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | super().__init__(model=model, preprocessor=preprocessor, **kwargs) | ||||
| self.categories = preprocessor.categories | |||||
| if preprocessor is None: | |||||
| self.preprocessor = DialogIntentPredictionPreprocessor( | |||||
| self.model.model_dir) | |||||
| self.categories = self.preprocessor.categories | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | ||||
| """process the prediction results | """process the prediction results | ||||
| @@ -29,13 +29,10 @@ class DialogModelingPipeline(Pipeline): | |||||
| or a SpaceForDialogModeling instance. | or a SpaceForDialogModeling instance. | ||||
| preprocessor (DialogModelingPreprocessor): An optional preprocessor instance. | preprocessor (DialogModelingPreprocessor): An optional preprocessor instance. | ||||
| """ | """ | ||||
| model = model if isinstance( | |||||
| model, SpaceForDialogModeling) else Model.from_pretrained(model) | |||||
| self.model = model | |||||
| if preprocessor is None: | |||||
| preprocessor = DialogModelingPreprocessor(model.model_dir) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | super().__init__(model=model, preprocessor=preprocessor, **kwargs) | ||||
| self.preprocessor = preprocessor | |||||
| if preprocessor is None: | |||||
| self.preprocessor = DialogModelingPreprocessor( | |||||
| self.model.model_dir) | |||||
| def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: | def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: | ||||
| """process the prediction results | """process the prediction results | ||||
| @@ -31,16 +31,13 @@ class DialogStateTrackingPipeline(Pipeline): | |||||
| from the model hub, or a SpaceForDialogStateTracking instance. | from the model hub, or a SpaceForDialogStateTracking instance. | ||||
| preprocessor (DialogStateTrackingPreprocessor): An optional preprocessor instance. | preprocessor (DialogStateTrackingPreprocessor): An optional preprocessor instance. | ||||
| """ | """ | ||||
| model = model if isinstance( | |||||
| model, SpaceForDST) else Model.from_pretrained(model) | |||||
| self.model = model | |||||
| if preprocessor is None: | |||||
| preprocessor = DialogStateTrackingPreprocessor(model.model_dir) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | super().__init__(model=model, preprocessor=preprocessor, **kwargs) | ||||
| if preprocessor is None: | |||||
| self.preprocessor = DialogStateTrackingPreprocessor( | |||||
| self.model.model_dir) | |||||
| self.tokenizer = preprocessor.tokenizer | |||||
| self.config = preprocessor.config | |||||
| self.tokenizer = self.preprocessor.tokenizer | |||||
| self.config = self.preprocessor.config | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | ||||
| """process the prediction results | """process the prediction results | ||||
| @@ -31,27 +31,22 @@ class DocumentSegmentationPipeline(Pipeline): | |||||
| model: Union[Model, str], | model: Union[Model, str], | ||||
| preprocessor: DocumentSegmentationPreprocessor = None, | preprocessor: DocumentSegmentationPreprocessor = None, | ||||
| **kwargs): | **kwargs): | ||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| model = model if isinstance(model, | |||||
| Model) else Model.from_pretrained(model) | |||||
| self.model_dir = model.model_dir | |||||
| self.model_cfg = model.forward() | |||||
| self.model_dir = self.model.model_dir | |||||
| self.model_cfg = self.model.forward() | |||||
| if self.model_cfg['type'] == 'bert': | if self.model_cfg['type'] == 'bert': | ||||
| config = BertConfig.from_pretrained(model.model_dir, num_labels=2) | |||||
| config = BertConfig.from_pretrained(self.model_dir, num_labels=2) | |||||
| elif self.model_cfg['type'] == 'ponet': | elif self.model_cfg['type'] == 'ponet': | ||||
| config = PoNetConfig.from_pretrained(model.model_dir, num_labels=2) | |||||
| config = PoNetConfig.from_pretrained(self.model_dir, num_labels=2) | |||||
| self.document_segmentation_model = model.build_with_config( | |||||
| self.document_segmentation_model = self.model.build_with_config( | |||||
| config=config) | config=config) | ||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = DocumentSegmentationPreprocessor( | |||||
| self.model_dir, config) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.preprocessor = preprocessor | |||||
| self.preprocessor = DocumentSegmentationPreprocessor( | |||||
| self.model.model_dir, config) | |||||
| def __call__( | def __call__( | ||||
| self, documents: Union[List[List[str]], List[str], | self, documents: Union[List[List[str]], List[str], | ||||
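DocumentSegmentationPipeline keeps its old behaviour, reading the backbone type from the model config and building either a BertConfig or a PoNetConfig with num_labels=2, but now takes everything from self.model after the base init. A reduced sketch of the dispatch, using the real Hugging Face BertConfig and stubbing the modelscope-internal PoNet branch:

    from transformers import BertConfig

    class PoNetConfigStub:   # stand-in for modelscope's PoNetConfig
        def __init__(self, num_labels: int):
            self.num_labels = num_labels

    def build_seg_config(model_cfg: dict):
        if model_cfg["type"] == "bert":
            return BertConfig(num_labels=2)
        elif model_cfg["type"] == "ponet":
            return PoNetConfigStub(num_labels=2)
        raise ValueError(f"unsupported backbone: {model_cfg['type']}")

    print(build_seg_config({"type": "bert"}).num_labels)   # 2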
| @@ -21,12 +21,10 @@ class FaqQuestionAnsweringPipeline(Pipeline): | |||||
| model: Union[str, Model], | model: Union[str, Model], | ||||
| preprocessor: Preprocessor = None, | preprocessor: Preprocessor = None, | ||||
| **kwargs): | **kwargs): | ||||
| model = Model.from_pretrained(model) if isinstance(model, | |||||
| str) else model | |||||
| if preprocessor is None: | |||||
| preprocessor = Preprocessor.from_pretrained( | |||||
| model.model_dir, **kwargs) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | super().__init__(model=model, preprocessor=preprocessor, **kwargs) | ||||
| if preprocessor is None: | |||||
| self.preprocessor = Preprocessor.from_pretrained( | |||||
| self.model.model_dir, **kwargs) | |||||
| def _sanitize_parameters(self, **pipeline_parameters): | def _sanitize_parameters(self, **pipeline_parameters): | ||||
| return pipeline_parameters, pipeline_parameters, pipeline_parameters | return pipeline_parameters, pipeline_parameters, pipeline_parameters | ||||
| @@ -37,11 +35,11 @@ class FaqQuestionAnsweringPipeline(Pipeline): | |||||
| sentence_vecs = sentence_vecs.detach().tolist() | sentence_vecs = sentence_vecs.detach().tolist() | ||||
| return sentence_vecs | return sentence_vecs | ||||
| def forward(self, inputs: [list, Dict[str, Any]], | |||||
| def forward(self, inputs: Union[list, Dict[str, Any]], | |||||
| **forward_params) -> Dict[str, Any]: | **forward_params) -> Dict[str, Any]: | ||||
| return self.model(inputs) | return self.model(inputs) | ||||
| def postprocess(self, inputs: [list, Dict[str, Any]], | |||||
| def postprocess(self, inputs: Union[list, Dict[str, Any]], | |||||
| **postprocess_params) -> Dict[str, Any]: | **postprocess_params) -> Dict[str, Any]: | ||||
| scores = inputs['scores'] | scores = inputs['scores'] | ||||
| labels = [] | labels = [] | ||||
| @@ -46,21 +46,18 @@ class FeatureExtractionPipeline(Pipeline): | |||||
| """ | """ | ||||
| model = model if isinstance(model, | |||||
| Model) else Model.from_pretrained(model) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = NLPPreprocessor( | |||||
| model.model_dir, | |||||
| self.preprocessor = NLPPreprocessor( | |||||
| self.model.model_dir, | |||||
| padding=kwargs.pop('padding', False), | padding=kwargs.pop('padding', False), | ||||
| sequence_length=kwargs.pop('sequence_length', 128)) | sequence_length=kwargs.pop('sequence_length', 128)) | ||||
| model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.model.eval() | |||||
| self.preprocessor = preprocessor | |||||
| self.config = Config.from_file( | self.config = Config.from_file( | ||||
| os.path.join(model.model_dir, ModelFile.CONFIGURATION)) | |||||
| self.tokenizer = preprocessor.tokenizer | |||||
| os.path.join(self.model.model_dir, ModelFile.CONFIGURATION)) | |||||
| self.tokenizer = self.preprocessor.tokenizer | |||||
| def forward(self, inputs: Dict[str, Any], | def forward(self, inputs: Dict[str, Any], | ||||
| **forward_params) -> Dict[str, Any]: | **forward_params) -> Dict[str, Any]: | ||||
| @@ -53,22 +53,18 @@ class FillMaskPipeline(Pipeline): | |||||
| If the xlm-roberta(xlm-roberta, veco, etc.) based model is used, the mask token is '<mask>'. | If the xlm-roberta(xlm-roberta, veco, etc.) based model is used, the mask token is '<mask>'. | ||||
| To view other examples plese check the tests/pipelines/test_fill_mask.py. | To view other examples plese check the tests/pipelines/test_fill_mask.py. | ||||
| """ | """ | ||||
| fill_mask_model = Model.from_pretrained(model) if isinstance( | |||||
| model, str) else model | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = Preprocessor.from_pretrained( | |||||
| fill_mask_model.model_dir, | |||||
| self.preprocessor = Preprocessor.from_pretrained( | |||||
| self.model.model_dir, | |||||
| first_sequence=first_sequence, | first_sequence=first_sequence, | ||||
| second_sequence=None, | second_sequence=None, | ||||
| sequence_length=kwargs.pop('sequence_length', 128)) | sequence_length=kwargs.pop('sequence_length', 128)) | ||||
| fill_mask_model.eval() | |||||
| assert hasattr( | |||||
| preprocessor, 'mask_id' | |||||
| ), 'The input preprocessor should have the mask_id attribute.' | |||||
| super().__init__( | |||||
| model=fill_mask_model, preprocessor=preprocessor, **kwargs) | |||||
| assert hasattr( | |||||
| self.preprocessor, 'mask_id' | |||||
| ), 'The input preprocessor should have the mask_id attribute.' | |||||
| self.model.eval() | |||||
| def forward(self, inputs: Dict[str, Any], | def forward(self, inputs: Dict[str, Any], | ||||
| **forward_params) -> Dict[str, Any]: | **forward_params) -> Dict[str, Any]: | ||||
| @@ -25,15 +25,12 @@ class InformationExtractionPipeline(Pipeline): | |||||
| model: Union[Model, str], | model: Union[Model, str], | ||||
| preprocessor: Optional[Preprocessor] = None, | preprocessor: Optional[Preprocessor] = None, | ||||
| **kwargs): | **kwargs): | ||||
| model = model if isinstance(model, | |||||
| Model) else Model.from_pretrained(model) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = RelationExtractionPreprocessor( | |||||
| model.model_dir, | |||||
| self.preprocessor = RelationExtractionPreprocessor( | |||||
| self.model.model_dir, | |||||
| sequence_length=kwargs.pop('sequence_length', 512)) | sequence_length=kwargs.pop('sequence_length', 512)) | ||||
| model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.model.eval() | |||||
| def forward(self, inputs: Dict[str, Any], | def forward(self, inputs: Dict[str, Any], | ||||
| **forward_params) -> Dict[str, Any]: | **forward_params) -> Dict[str, Any]: | ||||
| @@ -21,7 +21,7 @@ class MGLMTextSummarizationPipeline(Pipeline): | |||||
| def __init__(self, | def __init__(self, | ||||
| model: Union[MGLMForTextSummarization, str], | model: Union[MGLMForTextSummarization, str], | ||||
| preprocessor: [Preprocessor] = None, | |||||
| preprocessor: Optional[Preprocessor] = None, | |||||
| *args, | *args, | ||||
| **kwargs): | **kwargs): | ||||
| model = MGLMForTextSummarization(model) if isinstance(model, | model = MGLMForTextSummarization(model) if isinstance(model, | ||||
| @@ -50,15 +50,12 @@ class NamedEntityRecognitionPipeline(TokenClassificationPipeline): | |||||
| To view other examples plese check the tests/pipelines/test_named_entity_recognition.py. | To view other examples plese check the tests/pipelines/test_named_entity_recognition.py. | ||||
| """ | """ | ||||
| model = model if isinstance(model, | |||||
| Model) else Model.from_pretrained(model) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = TokenClassificationPreprocessor( | |||||
| model.model_dir, | |||||
| self.preprocessor = TokenClassificationPreprocessor( | |||||
| self.model.model_dir, | |||||
| sequence_length=kwargs.pop('sequence_length', 128)) | sequence_length=kwargs.pop('sequence_length', 128)) | ||||
| model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.model.eval() | |||||
| self.id2label = kwargs.get('id2label') | self.id2label = kwargs.get('id2label') | ||||
| if self.id2label is None and hasattr(self.preprocessor, 'id2label'): | if self.id2label is None and hasattr(self.preprocessor, 'id2label'): | ||||
| self.id2label = self.preprocessor.id2label | self.id2label = self.preprocessor.id2label | ||||
| @@ -73,13 +70,11 @@ class NamedEntityRecognitionThaiPipeline(NamedEntityRecognitionPipeline): | |||||
| model: Union[Model, str], | model: Union[Model, str], | ||||
| preprocessor: Optional[Preprocessor] = None, | preprocessor: Optional[Preprocessor] = None, | ||||
| **kwargs): | **kwargs): | ||||
| model = model if isinstance(model, | |||||
| Model) else Model.from_pretrained(model) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = NERPreprocessorThai( | |||||
| model.model_dir, | |||||
| self.preprocessor = NERPreprocessorThai( | |||||
| self.model.model_dir, | |||||
| sequence_length=kwargs.pop('sequence_length', 512)) | sequence_length=kwargs.pop('sequence_length', 512)) | ||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| @@ -91,10 +86,8 @@ class NamedEntityRecognitionVietPipeline(NamedEntityRecognitionPipeline): | |||||
| model: Union[Model, str], | model: Union[Model, str], | ||||
| preprocessor: Optional[Preprocessor] = None, | preprocessor: Optional[Preprocessor] = None, | ||||
| **kwargs): | **kwargs): | ||||
| model = model if isinstance(model, | |||||
| Model) else Model.from_pretrained(model) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = NERPreprocessorViet( | |||||
| model.model_dir, | |||||
| self.preprocessor = NERPreprocessorViet( | |||||
| self.model.model_dir, | |||||
| sequence_length=kwargs.pop('sequence_length', 512)) | sequence_length=kwargs.pop('sequence_length', 512)) | ||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| @@ -32,14 +32,13 @@ class SentenceEmbeddingPipeline(Pipeline): | |||||
| the model if supplied. | the model if supplied. | ||||
| sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value. | sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value. | ||||
| """ | """ | ||||
| model = Model.from_pretrained(model) if isinstance(model, | |||||
| str) else model | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = Preprocessor.from_pretrained( | |||||
| model.model_dir if isinstance(model, Model) else model, | |||||
| self.preprocessor = Preprocessor.from_pretrained( | |||||
| self.model.model_dir | |||||
| if isinstance(self.model, Model) else model, | |||||
| first_sequence=first_sequence, | first_sequence=first_sequence, | ||||
| sequence_length=kwargs.pop('sequence_length', 128)) | sequence_length=kwargs.pop('sequence_length', 128)) | ||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| def forward(self, inputs: Dict[str, Any], | def forward(self, inputs: Dict[str, Any], | ||||
| **forward_params) -> Dict[str, Any]: | **forward_params) -> Dict[str, Any]: | ||||
| @@ -1,5 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from typing import Any, Dict, Union | |||||
| from typing import Any, Dict, Optional, Union | |||||
| from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
| from modelscope.models.multi_modal import OfaForAllTasks | from modelscope.models.multi_modal import OfaForAllTasks | ||||
| @@ -18,7 +18,7 @@ class SummarizationPipeline(Pipeline): | |||||
| def __init__(self, | def __init__(self, | ||||
| model: Union[Model, str], | model: Union[Model, str], | ||||
| preprocessor: [Preprocessor] = None, | |||||
| preprocessor: Optional[Preprocessor] = None, | |||||
| **kwargs): | **kwargs): | ||||
| """Use `model` and `preprocessor` to create a Summarization pipeline for prediction. | """Use `model` and `preprocessor` to create a Summarization pipeline for prediction. | ||||
| @@ -27,19 +27,10 @@ class SummarizationPipeline(Pipeline): | |||||
| or a model id from the model hub, or a model instance. | or a model id from the model hub, or a model instance. | ||||
| preprocessor (Preprocessor): An optional preprocessor instance. | preprocessor (Preprocessor): An optional preprocessor instance. | ||||
| """ | """ | ||||
| super().__init__(model=model) | |||||
| assert isinstance(model, str) or isinstance(model, Model), \ | |||||
| 'model must be a single str or OfaForAllTasks' | |||||
| if isinstance(model, str): | |||||
| pipe_model = Model.from_pretrained(model) | |||||
| elif isinstance(model, Model): | |||||
| pipe_model = model | |||||
| else: | |||||
| raise NotImplementedError | |||||
| pipe_model.model.eval() | |||||
| if preprocessor is None and isinstance(pipe_model, OfaForAllTasks): | |||||
| preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) | |||||
| super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.model.eval() | |||||
| if preprocessor is None and isinstance(self.model, OfaForAllTasks): | |||||
| self.preprocessor = OfaPreprocessor(model_dir=self.model.model_dir) | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | ||||
| return inputs | return inputs | ||||
| @@ -41,21 +41,22 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||||
| preprocessor (TableQuestionAnsweringPreprocessor): a preprocessor instance | preprocessor (TableQuestionAnsweringPreprocessor): a preprocessor instance | ||||
| db (Database): a database to store tables in the database | db (Database): a database to store tables in the database | ||||
| """ | """ | ||||
| model = model if isinstance( | |||||
| model, TableQuestionAnswering) else Model.from_pretrained(model) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = TableQuestionAnsweringPreprocessor(model.model_dir) | |||||
| self.preprocessor = TableQuestionAnsweringPreprocessor( | |||||
| self.model.model_dir) | |||||
| # initilize tokenizer | # initilize tokenizer | ||||
| self.tokenizer = BertTokenizer( | self.tokenizer = BertTokenizer( | ||||
| os.path.join(model.model_dir, ModelFile.VOCAB_FILE)) | |||||
| os.path.join(self.model.model_dir, ModelFile.VOCAB_FILE)) | |||||
| # initialize database | # initialize database | ||||
| if db is None: | if db is None: | ||||
| self.db = Database( | self.db = Database( | ||||
| tokenizer=self.tokenizer, | tokenizer=self.tokenizer, | ||||
| table_file_path=os.path.join(model.model_dir, 'table.json'), | |||||
| syn_dict_file_path=os.path.join(model.model_dir, | |||||
| table_file_path=os.path.join(self.model.model_dir, | |||||
| 'table.json'), | |||||
| syn_dict_file_path=os.path.join(self.model.model_dir, | |||||
| 'synonym.txt')) | 'synonym.txt')) | ||||
| else: | else: | ||||
| self.db = db | self.db = db | ||||
| @@ -71,8 +72,6 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||||
| self.schema_link_dict = constant.schema_link_dict | self.schema_link_dict = constant.schema_link_dict | ||||
| self.limit_dict = constant.limit_dict | self.limit_dict = constant.limit_dict | ||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| def post_process_multi_turn(self, history_sql, result, table): | def post_process_multi_turn(self, history_sql, result, table): | ||||
| action = self.action_ops[result['action']] | action = self.action_ops[result['action']] | ||||
| headers = table['header_name'] | headers = table['header_name'] | ||||
| @@ -63,16 +63,14 @@ class Text2TextGenerationPipeline(Pipeline): | |||||
| To view other examples plese check the tests/pipelines/test_text_generation.py. | To view other examples plese check the tests/pipelines/test_text_generation.py. | ||||
| """ | """ | ||||
| model = model if isinstance(model, | |||||
| Model) else Model.from_pretrained(model) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = Text2TextGenerationPreprocessor( | |||||
| model.model_dir, | |||||
| self.preprocessor = Text2TextGenerationPreprocessor( | |||||
| self.model.model_dir, | |||||
| sequence_length=kwargs.pop('sequence_length', 128)) | sequence_length=kwargs.pop('sequence_length', 128)) | ||||
| self.tokenizer = preprocessor.tokenizer | |||||
| self.pipeline = model.pipeline.type | |||||
| model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.tokenizer = self.preprocessor.tokenizer | |||||
| self.pipeline = self.model.pipeline.type | |||||
| self.model.eval() | |||||
| def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: | def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: | ||||
| """ Provide specific preprocess for text2text generation pipeline in order to handl multi tasks | """ Provide specific preprocess for text2text generation pipeline in order to handl multi tasks | ||||
| @@ -53,25 +53,24 @@ class TextClassificationPipeline(Pipeline): | |||||
| NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence' and 'second_sequence' | NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence' and 'second_sequence' | ||||
| param will have no affection. | param will have no affection. | ||||
| """ | """ | ||||
| model = Model.from_pretrained(model) if isinstance(model, | |||||
| str) else model | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| if model.__class__.__name__ == 'OfaForAllTasks': | |||||
| preprocessor = Preprocessor.from_pretrained( | |||||
| model_name_or_path=model.model_dir, | |||||
| if self.model.__class__.__name__ == 'OfaForAllTasks': | |||||
| self.preprocessor = Preprocessor.from_pretrained( | |||||
| model_name_or_path=self.model.model_dir, | |||||
| type=Preprocessors.ofa_tasks_preprocessor, | type=Preprocessors.ofa_tasks_preprocessor, | ||||
| field=Fields.multi_modal) | field=Fields.multi_modal) | ||||
| else: | else: | ||||
| first_sequence = kwargs.pop('first_sequence', 'first_sequence') | first_sequence = kwargs.pop('first_sequence', 'first_sequence') | ||||
| second_sequence = kwargs.pop('second_sequence', None) | second_sequence = kwargs.pop('second_sequence', None) | ||||
| preprocessor = Preprocessor.from_pretrained( | |||||
| model if isinstance(model, str) else model.model_dir, | |||||
| self.preprocessor = Preprocessor.from_pretrained( | |||||
| self.model | |||||
| if isinstance(self.model, str) else self.model.model_dir, | |||||
| first_sequence=first_sequence, | first_sequence=first_sequence, | ||||
| second_sequence=second_sequence, | second_sequence=second_sequence, | ||||
| sequence_length=kwargs.pop('sequence_length', 512)) | sequence_length=kwargs.pop('sequence_length', 512)) | ||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.id2label = kwargs.get('id2label') | self.id2label = kwargs.get('id2label') | ||||
| if self.id2label is None and hasattr(self.preprocessor, 'id2label'): | if self.id2label is None and hasattr(self.preprocessor, 'id2label'): | ||||
| self.id2label = self.preprocessor.id2label | self.id2label = self.preprocessor.id2label | ||||
| @@ -40,14 +40,12 @@ class TextErrorCorrectionPipeline(Pipeline): | |||||
| To view other examples plese check the tests/pipelines/test_text_error_correction.py. | To view other examples plese check the tests/pipelines/test_text_error_correction.py. | ||||
| """ | """ | ||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| model = model if isinstance( | |||||
| model, | |||||
| BartForTextErrorCorrection) else Model.from_pretrained(model) | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = TextErrorCorrectionPreprocessor(model.model_dir) | |||||
| self.vocab = preprocessor.vocab | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.preprocessor = TextErrorCorrectionPreprocessor( | |||||
| self.model.model_dir) | |||||
| self.vocab = self.preprocessor.vocab | |||||
| def forward(self, inputs: Dict[str, Any], | def forward(self, inputs: Dict[str, Any], | ||||
| **forward_params) -> Dict[str, Any]: | **forward_params) -> Dict[str, Any]: | ||||
| @@ -51,15 +51,14 @@ class TextGenerationPipeline(Pipeline): | |||||
| To view other examples plese check the tests/pipelines/test_text_generation.py. | To view other examples plese check the tests/pipelines/test_text_generation.py. | ||||
| """ | """ | ||||
| model = model if isinstance(model, | |||||
| Model) else Model.from_pretrained(model) | |||||
| cfg = read_config(model.model_dir) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| cfg = read_config(self.model.model_dir) | |||||
| self.postprocessor = cfg.pop('postprocessor', 'decode') | self.postprocessor = cfg.pop('postprocessor', 'decode') | ||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor_cfg = cfg.preprocessor | preprocessor_cfg = cfg.preprocessor | ||||
| preprocessor_cfg.update({ | preprocessor_cfg.update({ | ||||
| 'model_dir': | 'model_dir': | ||||
| model.model_dir, | |||||
| self.model.model_dir, | |||||
| 'first_sequence': | 'first_sequence': | ||||
| first_sequence, | first_sequence, | ||||
| 'second_sequence': | 'second_sequence': | ||||
| @@ -67,9 +66,9 @@ class TextGenerationPipeline(Pipeline): | |||||
| 'sequence_length': | 'sequence_length': | ||||
| kwargs.pop('sequence_length', 128) | kwargs.pop('sequence_length', 128) | ||||
| }) | }) | ||||
| preprocessor = build_preprocessor(preprocessor_cfg, Fields.nlp) | |||||
| model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.preprocessor = build_preprocessor(preprocessor_cfg, | |||||
| Fields.nlp) | |||||
| self.model.eval() | |||||
| def _sanitize_parameters(self, **pipeline_parameters): | def _sanitize_parameters(self, **pipeline_parameters): | ||||
| return {}, pipeline_parameters, {} | return {}, pipeline_parameters, {} | ||||
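TextGenerationPipeline (and ProteinStructurePipeline further down) now read the configuration from self.model.model_dir after the base init and build the preprocessor from the config's preprocessor section via build_preprocessor. A stripped-down sketch of config-driven construction, with a toy registry standing in for modelscope's builder:

    # Building a preprocessor from a config dict, in the spirit of
    # cfg.preprocessor + build_preprocessor(...). Registry and classes are toys.
    PREPROCESSOR_REGISTRY = {}

    def register(name):
        def deco(cls):
            PREPROCESSOR_REGISTRY[name] = cls
            return cls
        return deco

    @register("text-gen-tokenizer")
    class TextGenPreprocessor:
        def __init__(self, model_dir, first_sequence=None, sequence_length=128):
            self.model_dir = model_dir
            self.sequence_length = sequence_length

    def build_preprocessor_from_cfg(cfg: dict):
        cfg = dict(cfg)                      # don't mutate the caller's config
        cls = PREPROCESSOR_REGISTRY[cfg.pop("type")]
        return cls(**cfg)

    preprocessor_cfg = {
        "type": "text-gen-tokenizer",
        "model_dir": "/cache/damo/text-generation",
        "sequence_length": 128,
    }
    print(build_preprocessor_from_cfg(preprocessor_cfg).sequence_length)   # 128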
| @@ -32,14 +32,12 @@ class TextRankingPipeline(Pipeline): | |||||
| the model if supplied. | the model if supplied. | ||||
| sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value. | sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value. | ||||
| """ | """ | ||||
| model = Model.from_pretrained(model) if isinstance(model, | |||||
| str) else model | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = Preprocessor.from_pretrained( | |||||
| model.model_dir, | |||||
| self.preprocessor = Preprocessor.from_pretrained( | |||||
| self.model.model_dir, | |||||
| sequence_length=kwargs.pop('sequence_length', 128)) | sequence_length=kwargs.pop('sequence_length', 128)) | ||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| def forward(self, inputs: Dict[str, Any], | def forward(self, inputs: Dict[str, Any], | ||||
| **forward_params) -> Dict[str, Any]: | **forward_params) -> Dict[str, Any]: | ||||
| @@ -39,15 +39,14 @@ class TokenClassificationPipeline(Pipeline): | |||||
| model (str or Model): A model instance or a model local dir or a model id in the model hub. | model (str or Model): A model instance or a model local dir or a model id in the model hub. | ||||
| preprocessor (Preprocessor): a preprocessor instance, must not be None. | preprocessor (Preprocessor): a preprocessor instance, must not be None. | ||||
| """ | """ | ||||
| model = Model.from_pretrained(model) if isinstance(model, | |||||
| str) else model | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = Preprocessor.from_pretrained( | |||||
| model.model_dir, | |||||
| self.preprocessor = Preprocessor.from_pretrained( | |||||
| self.model.model_dir, | |||||
| sequence_length=kwargs.pop('sequence_length', 128)) | sequence_length=kwargs.pop('sequence_length', 128)) | ||||
| model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.model.eval() | |||||
| self.id2label = kwargs.get('id2label') | self.id2label = kwargs.get('id2label') | ||||
| if self.id2label is None and hasattr(self.preprocessor, 'id2label'): | if self.id2label is None and hasattr(self.preprocessor, 'id2label'): | ||||
| self.id2label = self.preprocessor.id2label | self.id2label = self.preprocessor.id2label | ||||
| @@ -27,10 +27,10 @@ class TranslationQualityEstimationPipeline(Pipeline): | |||||
| def __init__(self, model: str, device: str = 'gpu', **kwargs): | def __init__(self, model: str, device: str = 'gpu', **kwargs): | ||||
| super().__init__(model=model, device=device) | super().__init__(model=model, device=device) | ||||
| model_file = os.path.join(model, ModelFile.TORCH_MODEL_FILE) | |||||
| model_file = os.path.join(self.model, ModelFile.TORCH_MODEL_FILE) | |||||
| with open(model_file, 'rb') as f: | with open(model_file, 'rb') as f: | ||||
| buffer = io.BytesIO(f.read()) | buffer = io.BytesIO(f.read()) | ||||
| self.tokenizer = XLMRobertaTokenizer.from_pretrained(model) | |||||
| self.tokenizer = XLMRobertaTokenizer.from_pretrained(self.model) | |||||
| self.model = torch.jit.load( | self.model = torch.jit.load( | ||||
| buffer, map_location=self.device).to(self.device) | buffer, map_location=self.device).to(self.device) | ||||
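TranslationQualityEstimationPipeline loads a TorchScript checkpoint by reading the file into an in-memory buffer and handing it to torch.jit.load with map_location, so a GPU-exported model can still be deserialized onto CPU. torch.jit.load does accept a file-like object; a self-contained round trip (the tiny module is invented for the demo):

    import io
    import torch
    import torch.nn as nn

    # Export a tiny scripted module, then load it back from an in-memory buffer,
    # mirroring how the pipeline reads ModelFile.TORCH_MODEL_FILE.
    scripted = torch.jit.script(nn.Linear(4, 2))
    buf = io.BytesIO()
    torch.jit.save(scripted, buf)
    buf.seek(0)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.jit.load(buf, map_location=device).to(device)
    print(model(torch.randn(1, 4, device=device)).shape)   # torch.Size([1, 2])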
| @@ -49,14 +49,13 @@ class WordSegmentationPipeline(TokenClassificationPipeline): | |||||
| To view other examples plese check the tests/pipelines/test_word_segmentation.py. | To view other examples plese check the tests/pipelines/test_word_segmentation.py. | ||||
| """ | """ | ||||
| model = model if isinstance(model, | |||||
| Model) else Model.from_pretrained(model) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = TokenClassificationPreprocessor( | |||||
| model.model_dir, | |||||
| self.preprocessor = TokenClassificationPreprocessor( | |||||
| self.model.model_dir, | |||||
| sequence_length=kwargs.pop('sequence_length', 128)) | sequence_length=kwargs.pop('sequence_length', 128)) | ||||
| model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.model.eval() | |||||
| self.id2label = kwargs.get('id2label') | self.id2label = kwargs.get('id2label') | ||||
| if self.id2label is None and hasattr(self.preprocessor, 'id2label'): | if self.id2label is None and hasattr(self.preprocessor, 'id2label'): | ||||
| self.id2label = self.preprocessor.id2label | self.id2label = self.preprocessor.id2label | ||||
| @@ -59,16 +59,14 @@ class ZeroShotClassificationPipeline(Pipeline): | |||||
| """ | """ | ||||
| assert isinstance(model, str) or isinstance(model, Model), \ | assert isinstance(model, str) or isinstance(model, Model), \ | ||||
| 'model must be a single str or Model' | 'model must be a single str or Model' | ||||
| model = model if isinstance(model, | |||||
| Model) else Model.from_pretrained(model) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.entailment_id = 0 | self.entailment_id = 0 | ||||
| self.contradiction_id = 2 | self.contradiction_id = 2 | ||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = ZeroShotClassificationPreprocessor( | |||||
| model.model_dir, | |||||
| self.preprocessor = ZeroShotClassificationPreprocessor( | |||||
| self.model.model_dir, | |||||
| sequence_length=kwargs.pop('sequence_length', 512)) | sequence_length=kwargs.pop('sequence_length', 512)) | ||||
| model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.model.eval() | |||||
| def _sanitize_parameters(self, **kwargs): | def _sanitize_parameters(self, **kwargs): | ||||
| preprocess_params = {} | preprocess_params = {} | ||||
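ZeroShotClassificationPipeline keeps entailment_id = 0 and contradiction_id = 2; the task is solved by running the input against each candidate label phrased as an NLI hypothesis and ranking labels by entailment. A compact sketch of one common scoring step on made-up logits (whether the real postprocess normalizes exactly this way is not visible in the hunk):

    import torch
    import torch.nn.functional as F

    ENTAILMENT_ID = 0          # matches the ids kept by the pipeline
    CONTRADICTION_ID = 2

    def rank_labels(nli_logits: torch.Tensor, labels):
        """nli_logits: [num_labels, 3] logits from pairing the premise with
        each candidate label phrased as a hypothesis."""
        entail_logits = nli_logits[:, ENTAILMENT_ID]
        scores = F.softmax(entail_logits, dim=0)      # single-label normalization
        order = torch.argsort(scores, descending=True)
        return [(labels[i], float(scores[i])) for i in order]

    fake_logits = torch.tensor([[2.0, 0.1, -1.0],     # hypothesis: "sports"
                                [-0.5, 0.2, 1.5]])    # hypothesis: "politics"
    print(rank_labels(fake_logits, ["sports", "politics"]))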
| @@ -105,22 +105,16 @@ class ProteinStructurePipeline(Pipeline): | |||||
| >>> print(pipeline_ins(protein)) | >>> print(pipeline_ins(protein)) | ||||
| """ | """ | ||||
| import copy | |||||
| model_path = copy.deepcopy(model) if isinstance(model, str) else None | |||||
| cfg = read_config(model_path) # only model is str | |||||
| self.cfg = cfg | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.cfg = read_config(self.model.model_dir) | |||||
| self.config = model_config( | self.config = model_config( | ||||
| cfg['pipeline']['model_name']) # alphafold config | |||||
| model = model if isinstance( | |||||
| model, Model) else Model.from_pretrained(model_path) | |||||
| self.postprocessor = cfg.pop('postprocessor', None) | |||||
| self.cfg['pipeline']['model_name']) # alphafold config | |||||
| self.postprocessor = self.cfg.pop('postprocessor', None) | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor_cfg = cfg.preprocessor | |||||
| preprocessor = build_preprocessor(preprocessor_cfg, Fields.science) | |||||
| model.eval() | |||||
| model.model.inference_mode() | |||||
| model.model_dir = model_path | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| preprocessor_cfg = self.cfg.preprocessor | |||||
| self.preprocessor = build_preprocessor(preprocessor_cfg, | |||||
| Fields.science) | |||||
| self.model.eval() | |||||
| def _sanitize_parameters(self, **pipeline_parameters): | def _sanitize_parameters(self, **pipeline_parameters): | ||||
| return pipeline_parameters, pipeline_parameters, pipeline_parameters | return pipeline_parameters, pipeline_parameters, pipeline_parameters | ||||
| @@ -6,7 +6,8 @@ from typing import Any, Dict, Optional, Sequence | |||||
| from modelscope.metainfo import Models, Preprocessors | from modelscope.metainfo import Models, Preprocessors | ||||
| from modelscope.utils.config import Config, ConfigDict | from modelscope.utils.config import Config, ConfigDict | ||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModeKeys, Tasks | |||||
| from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke, | |||||
| ModeKeys, Tasks) | |||||
| from modelscope.utils.hub import read_config, snapshot_download | from modelscope.utils.hub import read_config, snapshot_download | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .builder import build_preprocessor | from .builder import build_preprocessor | ||||
| @@ -194,7 +195,9 @@ class Preprocessor(ABC): | |||||
| """ | """ | ||||
| if not os.path.exists(model_name_or_path): | if not os.path.exists(model_name_or_path): | ||||
| model_dir = snapshot_download( | model_dir = snapshot_download( | ||||
| model_name_or_path, revision=revision) | |||||
| model_name_or_path, | |||||
| revision=revision, | |||||
| user_agent={Invoke.KEY: Invoke.PREPROCESSOR}) | |||||
| else: | else: | ||||
| model_dir = model_name_or_path | model_dir = model_name_or_path | ||||
| if cfg_dict is None: | if cfg_dict is None: | ||||
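Preprocessor.from_pretrained (and, below, OfaPreprocessor and CLIPPreprocessor) now tag hub downloads with user_agent={Invoke.KEY: Invoke.PREPROCESSOR}, so the hub can tell whether a download was triggered by a pipeline, a trainer, or a preprocessor. The constants live in modelscope.utils.constant.Invoke. A sketch of how such a tag might be folded into a User-Agent string; the header format and the constant values below are assumptions, not the actual hub protocol:

    from typing import Optional

    class Invoke:                      # stand-in for modelscope.utils.constant.Invoke
        KEY = "invoked_by"
        TRAINER = "trainer"
        PREPROCESSOR = "preprocessor"

    def build_user_agent(base: str, user_agent: Optional[dict] = None) -> str:
        extra = "; ".join(f"{k}/{v}" for k, v in (user_agent or {}).items())
        return f"{base}; {extra}" if extra else base

    print(build_user_agent("modelscope/1.0", {Invoke.KEY: Invoke.PREPROCESSOR}))
    # modelscope/1.0; invoked_by/preprocessor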
| @@ -14,7 +14,8 @@ from modelscope.metainfo import Preprocessors | |||||
| from modelscope.pipelines.base import Input | from modelscope.pipelines.base import Input | ||||
| from modelscope.preprocessors import load_image | from modelscope.preprocessors import load_image | ||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import Fields, ModeKeys, ModelFile, Tasks | |||||
| from modelscope.utils.constant import (Fields, Invoke, ModeKeys, ModelFile, | |||||
| Tasks) | |||||
| from .base import Preprocessor | from .base import Preprocessor | ||||
| from .builder import PREPROCESSORS | from .builder import PREPROCESSORS | ||||
| from .ofa import * # noqa | from .ofa import * # noqa | ||||
| @@ -57,7 +58,7 @@ class OfaPreprocessor(Preprocessor): | |||||
| Tasks.auto_speech_recognition: OfaASRPreprocessor | Tasks.auto_speech_recognition: OfaASRPreprocessor | ||||
| } | } | ||||
| model_dir = model_dir if osp.exists(model_dir) else snapshot_download( | model_dir = model_dir if osp.exists(model_dir) else snapshot_download( | ||||
| model_dir) | |||||
| model_dir, user_agent={Invoke.KEY: Invoke.PREPROCESSOR}) | |||||
| self.cfg = Config.from_file( | self.cfg = Config.from_file( | ||||
| osp.join(model_dir, ModelFile.CONFIGURATION)) | osp.join(model_dir, ModelFile.CONFIGURATION)) | ||||
| self.preprocess = preprocess_mapping[self.cfg.task]( | self.preprocess = preprocess_mapping[self.cfg.task]( | ||||
| @@ -131,7 +132,7 @@ class CLIPPreprocessor(Preprocessor): | |||||
| """ | """ | ||||
| super().__init__(*args, **kwargs) | super().__init__(*args, **kwargs) | ||||
| model_dir = model_dir if osp.exists(model_dir) else snapshot_download( | model_dir = model_dir if osp.exists(model_dir) else snapshot_download( | ||||
| model_dir) | |||||
| model_dir, user_agent={Invoke.KEY: Invoke.PREPROCESSOR}) | |||||
| self.mode = mode | self.mode = mode | ||||
| # text tokenizer | # text tokenizer | ||||
| from modelscope.models.multi_modal.clip.bert_tokenizer import FullTokenizer | from modelscope.models.multi_modal.clip.bert_tokenizer import FullTokenizer | ||||
| @@ -5,6 +5,7 @@ import random | |||||
| from pathlib import Path | from pathlib import Path | ||||
| from typing import Any, Dict | from typing import Any, Dict | ||||
| import librosa | |||||
| import soundfile as sf | import soundfile as sf | ||||
| import torch | import torch | ||||
| from fairseq.data.audio.feature_transforms import \ | from fairseq.data.audio.feature_transforms import \ | ||||
| @@ -54,9 +55,13 @@ class OfaASRPreprocessor(OfaBasePreprocessor): | |||||
| def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | ||||
| speed = random.choice([0.9, 1.0, 1.1]) | speed = random.choice([0.9, 1.0, 1.1]) | ||||
| wav, sr = sf.read(self.column_map['wav']) | |||||
| wav, sr = librosa.load(data[self.column_map['wav']], 16000, mono=True) | |||||
| fbank = self.prepare_fbank( | fbank = self.prepare_fbank( | ||||
| torch.tensor([wav], dtype=torch.float32), sr, speed, is_train=True) | |||||
| torch.tensor([wav], dtype=torch.float32), | |||||
| sr, | |||||
| speed, | |||||
| target_sample_rate=16000, | |||||
| is_train=True) | |||||
| fbank_mask = torch.tensor([True]) | fbank_mask = torch.tensor([True]) | ||||
| sample = { | sample = { | ||||
| 'fbank': fbank, | 'fbank': fbank, | ||||
| @@ -86,11 +91,12 @@ class OfaASRPreprocessor(OfaBasePreprocessor): | |||||
| def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | ||||
| speed = 1.0 | speed = 1.0 | ||||
| wav, sr = sf.read(data[self.column_map['wav']]) | |||||
| wav, sr = librosa.load(data[self.column_map['wav']], 16000, mono=True) | |||||
| fbank = self.prepare_fbank( | fbank = self.prepare_fbank( | ||||
| torch.tensor([wav], dtype=torch.float32), | torch.tensor([wav], dtype=torch.float32), | ||||
| sr, | sr, | ||||
| speed, | speed, | ||||
| target_sample_rate=16000, | |||||
| is_train=False) | is_train=False) | ||||
| fbank_mask = torch.tensor([True]) | fbank_mask = torch.tensor([True]) | ||||
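The ASR preprocessor change fixes two things at once: the training branch above passed the column name (self.column_map['wav']) to sf.read instead of the sample's path, and soundfile returns audio at whatever rate the file was stored in. librosa.load(path, 16000, mono=True) reads the file referenced by data[...], downmixes, and resamples to the 16 kHz the fbank front end expects. A small comparison; the path is a placeholder, and newer librosa (>= 0.10) requires the rate as a keyword:

    import librosa
    import soundfile as sf

    path = "sample.wav"   # placeholder path

    # soundfile: samples come back at the file's native rate; caller must resample.
    native_wav, native_sr = sf.read(path)

    # librosa: decode + mono downmix + resample in one call.
    wav, sr = librosa.load(path, sr=16000, mono=True)
    assert sr == 16000 and wav.ndim == 1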
| @@ -170,10 +170,15 @@ class OfaBasePreprocessor: | |||||
| else load_image(path_or_url_or_pil) | else load_image(path_or_url_or_pil) | ||||
| return image | return image | ||||
| def prepare_fbank(self, waveform, sample_rate, speed, is_train): | |||||
| waveform, _ = torchaudio.sox_effects.apply_effects_tensor( | |||||
| def prepare_fbank(self, | |||||
| waveform, | |||||
| sample_rate, | |||||
| speed, | |||||
| target_sample_rate=16000, | |||||
| is_train=False): | |||||
| waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor( | |||||
| waveform, sample_rate, | waveform, sample_rate, | ||||
| [['speed', str(speed)], ['rate', str(sample_rate)]]) | |||||
| [['speed', str(speed)], ['rate', str(target_sample_rate)]]) | |||||
| _waveform, _ = convert_waveform( | _waveform, _ = convert_waveform( | ||||
| waveform, sample_rate, to_mono=True, normalize_volume=True) | waveform, sample_rate, to_mono=True, normalize_volume=True) | ||||
| # Kaldi compliance: 16-bit signed integers | # Kaldi compliance: 16-bit signed integers | ||||
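prepare_fbank previously rated the audio back to its own sample_rate and threw away the rate returned by apply_effects_tensor; with the new target_sample_rate argument the chain is explicit: speed-perturb, then resample to 16 kHz, and keep the returned rate for the downstream fbank computation. A standalone check of that behaviour (torchaudio's sox backend must be available; the input is fake audio):

    import torch
    import torchaudio

    waveform = torch.randn(1, 48000)          # 1 s of fake audio at 48 kHz
    sample_rate, speed, target_sample_rate = 48000, 1.1, 16000

    out, out_sr = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate,
        [["speed", str(speed)], ["rate", str(target_sample_rate)]])

    # speed perturbation shortens the clip; the rate effect resamples to 16 kHz
    print(out_sr)          # 16000
    print(out.shape[1])    # roughly 48000 / 1.1 * (16000 / 48000) samples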
| @@ -8,7 +8,6 @@ import torch | |||||
| from torch import nn as nn | from torch import nn as nn | ||||
| from torch import optim as optim | from torch import optim as optim | ||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Trainers | from modelscope.metainfo import Trainers | ||||
| from modelscope.models import Model, TorchModel | from modelscope.models import Model, TorchModel | ||||
| from modelscope.msdatasets.task_datasets.audio import KWSDataLoader, KWSDataset | from modelscope.msdatasets.task_datasets.audio import KWSDataLoader, KWSDataset | ||||
| @@ -54,12 +53,8 @@ class KWSFarfieldTrainer(BaseTrainer): | |||||
| **kwargs): | **kwargs): | ||||
| if isinstance(model, str): | if isinstance(model, str): | ||||
| if os.path.exists(model): | |||||
| self.model_dir = model if os.path.isdir( | |||||
| model) else os.path.dirname(model) | |||||
| else: | |||||
| self.model_dir = snapshot_download( | |||||
| model, revision=model_revision) | |||||
| self.model_dir = self.get_or_download_model_dir( | |||||
| model, model_revision) | |||||
| if cfg_file is None: | if cfg_file is None: | ||||
| cfg_file = os.path.join(self.model_dir, | cfg_file = os.path.join(self.model_dir, | ||||
| ModelFile.CONFIGURATION) | ModelFile.CONFIGURATION) | ||||
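KWSFarfieldTrainer now delegates model-path resolution to the get_or_download_model_dir helper that the next hunk adds to BaseTrainer (CsanmtTranslationTrainer and NlpEpochBasedTrainer below switch to it as well). The helper's three cases, restated as a plain function with the hub download stubbed out (the stub is not the real snapshot_download):

    import os

    def resolve_model_dir(model: str, download=lambda m: f"/cache/{m}"):
        """Local dir -> itself, local file -> its parent dir, otherwise a hub id."""
        if os.path.exists(model):
            return model if os.path.isdir(model) else os.path.dirname(model)
        return download(model)

    print(resolve_model_dir("damo/speech_dfsmn_kws"))   # /cache/damo/speech_dfsmn_kws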
| @@ -1,11 +1,14 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | |||||
| import time | import time | ||||
| from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||
| from typing import Callable, Dict, List, Optional, Tuple, Union | from typing import Callable, Dict, List, Optional, Tuple, Union | ||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.trainers.builder import TRAINERS | from modelscope.trainers.builder import TRAINERS | ||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import Invoke | |||||
| from .utils.log_buffer import LogBuffer | from .utils.log_buffer import LogBuffer | ||||
| @@ -32,6 +35,17 @@ class BaseTrainer(ABC): | |||||
| self.log_buffer = LogBuffer() | self.log_buffer = LogBuffer() | ||||
| self.timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) | self.timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) | ||||
| def get_or_download_model_dir(self, model, model_revision=None): | |||||
| if os.path.exists(model): | |||||
| model_cache_dir = model if os.path.isdir( | |||||
| model) else os.path.dirname(model) | |||||
| else: | |||||
| model_cache_dir = snapshot_download( | |||||
| model, | |||||
| revision=model_revision, | |||||
| user_agent={Invoke.KEY: Invoke.TRAINER}) | |||||
| return model_cache_dir | |||||
| @abstractmethod | @abstractmethod | ||||
| def train(self, *args, **kwargs): | def train(self, *args, **kwargs): | ||||
| """ Train (and evaluate) process | """ Train (and evaluate) process | ||||
| @@ -20,7 +20,7 @@ from modelscope.trainers.builder import TRAINERS | |||||
| from modelscope.trainers.optimizer.builder import build_optimizer | from modelscope.trainers.optimizer.builder import build_optimizer | ||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys, | from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys, | ||||
| ModeKeys) | |||||
| Invoke, ModeKeys) | |||||
| from .clip_trainer_utils import get_loss, get_optimizer_params, get_schedule | from .clip_trainer_utils import get_loss, get_optimizer_params, get_schedule | ||||
| @@ -52,7 +52,8 @@ class CLIPTrainer(EpochBasedTrainer): | |||||
| model_revision: Optional[str] = DEFAULT_MODEL_REVISION, | model_revision: Optional[str] = DEFAULT_MODEL_REVISION, | ||||
| seed: int = 42, | seed: int = 42, | ||||
| **kwargs): | **kwargs): | ||||
| model = Model.from_pretrained(model, revision=model_revision) | |||||
| model = Model.from_pretrained( | |||||
| model, revision=model_revision, invoked_by=Invoke.TRAINER) | |||||
| # for training & eval, we convert the model from FP16 back to FP32 | # for training & eval, we convert the model from FP16 back to FP32 | ||||
| # to compatible with modelscope amp training | # to compatible with modelscope amp training | ||||
| convert_models_to_fp32(model) | convert_models_to_fp32(model) | ||||
| @@ -23,7 +23,7 @@ from modelscope.trainers.optimizer.builder import build_optimizer | |||||
| from modelscope.trainers.parallel.utils import is_parallel | from modelscope.trainers.parallel.utils import is_parallel | ||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys, | from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys, | ||||
| ModeKeys) | |||||
| Invoke, ModeKeys) | |||||
| from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion, | from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion, | ||||
| get_schedule) | get_schedule) | ||||
| @@ -49,7 +49,8 @@ class OFATrainer(EpochBasedTrainer): | |||||
| model_revision: Optional[str] = DEFAULT_MODEL_REVISION, | model_revision: Optional[str] = DEFAULT_MODEL_REVISION, | ||||
| seed: int = 42, | seed: int = 42, | ||||
| **kwargs): | **kwargs): | ||||
| model = Model.from_pretrained(model, revision=model_revision) | |||||
| model = Model.from_pretrained( | |||||
| model, revision=model_revision, invoked_by=Invoke.TRAINER) | |||||
| model_dir = model.model_dir | model_dir = model.model_dir | ||||
| self.cfg_modify_fn = cfg_modify_fn | self.cfg_modify_fn = cfg_modify_fn | ||||
| cfg = self.rebuild_config(Config.from_file(cfg_file)) | cfg = self.rebuild_config(Config.from_file(cfg_file)) | ||||
@@ -7,21 +7,17 @@ from typing import Callable, Dict, Optional
 import numpy as np
 import torch
 import torch.nn as nn
-import torchvision.datasets as datasets
-import torchvision.transforms as transforms
 from sklearn.metrics import confusion_matrix
-from torch.optim import AdamW
 from torch.utils.data import DataLoader, Dataset
 
 from modelscope.metainfo import Trainers
 from modelscope.models.base import Model
-from modelscope.msdatasets import MsDataset
 from modelscope.trainers.base import BaseTrainer
 from modelscope.trainers.builder import TRAINERS
-from modelscope.trainers.multi_modal.team.team_trainer_utils import (
-    get_optimizer, train_mapping, val_mapping)
+from modelscope.trainers.multi_modal.team.team_trainer_utils import \
+    get_optimizer
 from modelscope.utils.config import Config
-from modelscope.utils.constant import DownloadMode, ModeKeys
+from modelscope.utils.constant import Invoke
 from modelscope.utils.logger import get_logger
 
 logger = get_logger()
@@ -36,7 +32,7 @@ class TEAMImgClsTrainer(BaseTrainer):
         super().__init__(cfg_file)
         self.cfg = Config.from_file(cfg_file)
 
-        team_model = Model.from_pretrained(model)
+        team_model = Model.from_pretrained(model, invoked_by=Invoke.TRAINER)
         image_model = team_model.model.image_model.vision_transformer
         classification_model = nn.Sequential(
             OrderedDict([('encoder', image_model),
@@ -24,8 +24,7 @@ logger = get_logger()
 class CsanmtTranslationTrainer(BaseTrainer):
 
     def __init__(self, model: str, cfg_file: str = None, *args, **kwargs):
-        if not osp.exists(model):
-            model = snapshot_download(model)
+        model = self.get_or_download_model_dir(model)
         tf.reset_default_graph()
 
         self.model_dir = model
@@ -10,7 +10,6 @@ import torch
 from torch import nn
 from torch.utils.data import Dataset
 
-from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Trainers
 from modelscope.metrics.builder import build_metric
 from modelscope.models.base import Model, TorchModel
@@ -478,11 +477,7 @@ class NlpEpochBasedTrainer(EpochBasedTrainer):
         """
 
         if isinstance(model, str):
-            if os.path.exists(model):
-                model_dir = model if os.path.isdir(model) else os.path.dirname(
-                    model)
-            else:
-                model_dir = snapshot_download(model, revision=model_revision)
+            model_dir = self.get_or_download_model_dir(model, model_revision)
             if cfg_file is None:
                 cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
         else:
@@ -14,7 +14,6 @@ from torch.utils.data import DataLoader, Dataset
 from torch.utils.data.dataloader import default_collate
 from torch.utils.data.distributed import DistributedSampler
 
-from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Trainers
 from modelscope.metrics import build_metric, task_default_metrics
 from modelscope.models.base import Model, TorchModel
@@ -98,12 +97,8 @@ class EpochBasedTrainer(BaseTrainer):
         self._seed = seed
         set_random_seed(self._seed)
         if isinstance(model, str):
-            if os.path.exists(model):
-                self.model_dir = model if os.path.isdir(
-                    model) else os.path.dirname(model)
-            else:
-                self.model_dir = snapshot_download(
-                    model, revision=model_revision)
+            self.model_dir = self.get_or_download_model_dir(
+                model, model_revision)
             if cfg_file is None:
                 cfg_file = os.path.join(self.model_dir,
                                         ModelFile.CONFIGURATION)
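The three trainer hunks above fold the same path-or-hub-id resolution into a shared `get_or_download_model_dir` helper (called as a trainer method in the diff). Judging from the removed code, its behaviour is presumably along these lines; the sketch below is a plain function under that assumption, not the repository's exact implementation:

```python
import os

from modelscope.hub.snapshot_download import snapshot_download


def get_or_download_model_dir(model: str, model_revision: str = None) -> str:
    # A local path is used directly (or its parent directory if a file was
    # given); any other string is treated as a hub model id and downloaded.
    if os.path.exists(model):
        return model if os.path.isdir(model) else os.path.dirname(model)
    return snapshot_download(model, revision=model_revision)
```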
@@ -44,6 +44,7 @@ class CVTasks(object):
     image_segmentation = 'image-segmentation'
     semantic_segmentation = 'semantic-segmentation'
+    image_depth_estimation = 'image-depth-estimation'
     portrait_matting = 'portrait-matting'
     text_driven_segmentation = 'text-driven-segmentation'
     shop_segmentation = 'shop-segmentation'
@@ -293,6 +294,14 @@ class ModelFile(object):
     TS_MODEL_FILE = 'model.ts'
 
 
+class Invoke(object):
+    KEY = 'invoked_by'
+    PRETRAINED = 'from_pretrained'
+    PIPELINE = 'pipeline'
+    TRAINER = 'trainer'
+    PREPROCESSOR = 'preprocessor'
+
+
 class ConfigFields(object):
     """ First level keyword in configuration file
     """
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import cv2
+import matplotlib.pyplot as plt
 import numpy as np
 
 from modelscope.outputs import OutputKeys
@@ -439,3 +440,11 @@ def show_image_object_detection_auto_result(img_path,
     if save_path is not None:
         cv2.imwrite(save_path, img)
     return img
+
+
+def depth_to_color(depth):
+    colormap = plt.get_cmap('plasma')
+    depth_color = (colormap(
+        (depth.max() - depth) / depth.max()) * 2**8).astype(np.uint8)[:, :, :3]
+    depth_color = cv2.cvtColor(depth_color, cv2.COLOR_RGB2BGR)
+    return depth_color
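`depth_to_color` maps a raw depth map onto matplotlib's `plasma` colormap (near pixels land on the bright end) and returns an 8-bit BGR image ready for `cv2.imwrite`. A small usage sketch with a synthetic depth map; a real map would come from the image-depth-estimation pipeline, as in the new test further down:

```python
import cv2
import numpy as np

from modelscope.utils.cv.image_utils import depth_to_color

# Synthetic H x W depth map in arbitrary units; values only need to be
# non-negative since the function normalizes by depth.max().
depth = np.linspace(0.5, 10.0, 480 * 640, dtype=np.float32).reshape(480, 640)
depth_vis = depth_to_color(depth)  # uint8, H x W x 3, BGR
cv2.imwrite('depth_vis.jpg', depth_vis)
```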
@@ -1,4 +1,5 @@
 ftfy>=6.0.3
+librosa
 ofa>=0.0.2
 pycocoevalcap>=1.2
 pycocotools>=2.0.4
@@ -28,7 +28,7 @@ class ExtractiveSummarizationTest(unittest.TestCase, DemoCompatibilityCheck):
         result = p(documents=documents)
         return result
 
-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_doc(self):
         logger.info(
             'Run doc extractive summarization (PoNet) with one document ...')
@@ -37,7 +37,7 @@ class ExtractiveSummarizationTest(unittest.TestCase, DemoCompatibilityCheck):
             model_id=self.ponet_doc_model_id, documents=self.sentences)
         print(result[OutputKeys.TEXT])
 
-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_topic(self):
         logger.info(
             'Run topic extractive summarization (PoNet) with one document ...')
@@ -0,0 +1,35 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+import cv2
+import numpy as np
+
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.cv.image_utils import depth_to_color
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class ImageDepthEstimationTest(unittest.TestCase, DemoCompatibilityCheck):
+
+    def setUp(self) -> None:
+        self.task = 'image-depth-estimation'
+        self.model_id = 'damo/cv_newcrfs_image-depth-estimation_indoor'
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_image_depth_estimation(self):
+        input_location = 'data/test/images/image_depth_estimation.jpg'
+        estimator = pipeline(Tasks.image_depth_estimation, model=self.model_id)
+        result = estimator(input_location)
+        depths = result[OutputKeys.DEPTHS]
+        depth_viz = depth_to_color(depths[0].squeeze().cpu().numpy())
+        cv2.imwrite('result.jpg', depth_viz)
+        print('test_image_depth_estimation DONE')
+
+
+if __name__ == '__main__':
+    unittest.main()