panoptic segmentation 模型接入 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9758389master
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:59b1da30af12f76b691990363e0d221050a59cf53fc4a97e776bcb00228c6c2a | |||
size 245864 |
@@ -20,6 +20,7 @@ class Models(object): | |||
product_retrieval_embedding = 'product-retrieval-embedding' | |||
body_2d_keypoints = 'body-2d-keypoints' | |||
crowd_counting = 'HRNetCrowdCounting' | |||
panoptic_segmentation = 'swinL-panoptic-segmentation' | |||
image_reid_person = 'passvitb' | |||
video_summarization = 'pgl-video-summarization' | |||
@@ -114,6 +115,7 @@ class Pipelines(object): | |||
tinynas_classification = 'tinynas-classification' | |||
crowd_counting = 'hrnet-crowd-counting' | |||
video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking' | |||
image_panoptic_segmentation = 'image-panoptic-segmentation' | |||
video_summarization = 'googlenet_pgl_video_summarization' | |||
image_reid_person = 'passvitb-image-reid-person' | |||
@@ -3,8 +3,9 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, | |||
cartoon, cmdssl_video_embedding, crowd_counting, face_detection, | |||
face_generation, image_classification, image_color_enhance, | |||
image_colorization, image_denoise, image_instance_segmentation, | |||
image_portrait_enhancement, image_reid_person, | |||
image_to_image_generation, image_to_image_translation, | |||
object_detection, product_retrieval_embedding, | |||
salient_detection, super_resolution, | |||
video_single_object_tracking, video_summarization, virual_tryon) | |||
image_panoptic_segmentation, image_portrait_enhancement, | |||
image_reid_person, image_to_image_generation, | |||
image_to_image_translation, object_detection, | |||
product_retrieval_embedding, salient_detection, | |||
super_resolution, video_single_object_tracking, | |||
video_summarization, virual_tryon) |
@@ -0,0 +1,22 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .panseg_model import SwinLPanopticSegmentation | |||
else: | |||
_import_structure = { | |||
'panseg_model': ['SwinLPanopticSegmentation'], | |||
} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
@@ -0,0 +1,54 @@ | |||
import os.path as osp | |||
import torch | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base.base_torch_model import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
@MODELS.register_module( | |||
Tasks.image_segmentation, module_name=Models.panoptic_segmentation) | |||
class SwinLPanopticSegmentation(TorchModel): | |||
def __init__(self, model_dir: str, **kwargs): | |||
"""str -- model file root.""" | |||
super().__init__(model_dir, **kwargs) | |||
from mmcv.runner import load_checkpoint | |||
import mmcv | |||
from mmdet.models import build_detector | |||
config = osp.join(model_dir, 'config.py') | |||
cfg = mmcv.Config.fromfile(config) | |||
if 'pretrained' in cfg.model: | |||
cfg.model.pretrained = None | |||
elif 'init_cfg' in cfg.model.backbone: | |||
cfg.model.backbone.init_cfg = None | |||
# build model | |||
cfg.model.train_cfg = None | |||
self.model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) | |||
# load model | |||
model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) | |||
checkpoint = load_checkpoint( | |||
self.model, model_path, map_location='cpu') | |||
self.CLASSES = checkpoint['meta']['CLASSES'] | |||
self.num_classes = len(self.CLASSES) | |||
self.cfg = cfg | |||
def inference(self, data): | |||
"""data is dict,contain img and img_metas,follow with mmdet.""" | |||
with torch.no_grad(): | |||
results = self.model(return_loss=False, rescale=True, **data) | |||
return results | |||
def forward(self, Inputs): | |||
import pdb | |||
pdb.set_trace() | |||
return self.model(**Inputs) |
@@ -23,6 +23,7 @@ if TYPE_CHECKING: | |||
from .image_denoise_pipeline import ImageDenoisePipeline | |||
from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline | |||
from .image_matting_pipeline import ImageMattingPipeline | |||
from .image_panoptic_segmentation_pipeline import ImagePanopticSegmentationPipeline | |||
from .image_portrait_enhancement_pipeline import ImagePortraitEnhancementPipeline | |||
from .image_reid_person_pipeline import ImageReidPersonPipeline | |||
from .image_style_transfer_pipeline import ImageStyleTransferPipeline | |||
@@ -37,6 +38,7 @@ if TYPE_CHECKING: | |||
from .tinynas_classification_pipeline import TinynasClassificationPipeline | |||
from .video_category_pipeline import VideoCategoryPipeline | |||
from .virtual_try_on_pipeline import VirtualTryonPipeline | |||
else: | |||
_import_structure = { | |||
'action_recognition_pipeline': ['ActionRecognitionPipeline'], | |||
@@ -59,6 +61,8 @@ else: | |||
'image_instance_segmentation_pipeline': | |||
['ImageInstanceSegmentationPipeline'], | |||
'image_matting_pipeline': ['ImageMattingPipeline'], | |||
'image_panoptic_segmentation_pipeline': | |||
['ImagePanopticSegmentationPipeline'], | |||
'image_portrait_enhancement_pipeline': | |||
['ImagePortraitEnhancementPipeline'], | |||
'image_reid_person_pipeline': ['ImageReidPersonPipeline'], | |||
@@ -0,0 +1,103 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict, Union | |||
import cv2 | |||
import numpy as np | |||
import PIL | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.image_segmentation, | |||
module_name=Pipelines.image_panoptic_segmentation) | |||
class ImagePanopticSegmentationPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
use `model` to create a image panoptic segmentation pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
super().__init__(model=model, **kwargs) | |||
logger.info('panoptic segmentation model, pipeline init') | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
from mmdet.datasets.pipelines import Compose | |||
from mmcv.parallel import collate, scatter | |||
from mmdet.datasets import replace_ImageToTensor | |||
cfg = self.model.cfg | |||
# build the data pipeline | |||
if isinstance(input, str): | |||
# input is str, file names, pipeline loadimagefromfile | |||
# collect data | |||
data = dict(img_info=dict(filename=input), img_prefix=None) | |||
elif isinstance(input, PIL.Image.Image): | |||
cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' | |||
img = np.array(input.convert('RGB')) | |||
# collect data | |||
data = dict(img=img) | |||
elif isinstance(input, np.ndarray): | |||
cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' | |||
if len(input.shape) == 2: | |||
img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) | |||
else: | |||
img = input | |||
img = img[:, :, ::-1] # in rgb order | |||
# collect data | |||
data = dict(img=img) | |||
else: | |||
raise TypeError(f'input should be either str, PIL.Image,' | |||
f' np.array, but got {type(input)}') | |||
cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline) | |||
test_pipeline = Compose(cfg.data.test.pipeline) | |||
data = test_pipeline(data) | |||
# copy from mmdet_model collect data | |||
data = collate([data], samples_per_gpu=1) | |||
data['img_metas'] = [ | |||
img_metas.data[0] for img_metas in data['img_metas'] | |||
] | |||
data['img'] = [img.data[0] for img in data['img']] | |||
if next(self.model.parameters()).is_cuda: | |||
# scatter to specified GPU | |||
data = scatter(data, [next(self.model.parameters()).device])[0] | |||
return data | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
results = self.model.inference(input) | |||
return results | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
# bz=1, tcguo | |||
pan_results = inputs[0]['pan_results'] | |||
INSTANCE_OFFSET = 1000 | |||
ids = np.unique(pan_results)[::-1] | |||
legal_indices = ids != self.model.num_classes # for VOID label | |||
ids = ids[legal_indices] | |||
labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64) | |||
segms = (pan_results[None] == ids[:, None, None]) | |||
masks = [it.astype(np.int) for it in segms] | |||
labels_txt = np.array(self.model.CLASSES)[labels].tolist() | |||
outputs = { | |||
OutputKeys.MASKS: masks, | |||
OutputKeys.LABELS: labels_txt, | |||
OutputKeys.SCORES: [0.999 for _ in range(len(labels_txt))] | |||
} | |||
return outputs |
@@ -134,3 +134,22 @@ def show_video_tracking_result(video_in_path, bboxes, video_save_path): | |||
video_writer.write(frame) | |||
video_writer.release | |||
cap.release() | |||
def panoptic_seg_masks_to_image(masks): | |||
draw_img = np.zeros([masks[0].shape[0], masks[0].shape[1], 3]) | |||
from mmdet.core.visualization.palette import get_palette | |||
mask_palette = get_palette('coco', 133) | |||
from mmdet.core.visualization.image import _get_bias_color | |||
taken_colors = set([0, 0, 0]) | |||
for i, mask in enumerate(masks): | |||
color_mask = mask_palette[i] | |||
while tuple(color_mask) in taken_colors: | |||
color_mask = _get_bias_color(color_mask) | |||
taken_colors.add(tuple(color_mask)) | |||
mask = mask.astype(bool) | |||
draw_img[mask] = color_mask | |||
return draw_img |
@@ -0,0 +1,40 @@ | |||
import unittest | |||
import cv2 | |||
import PIL | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.cv.image_utils import panoptic_seg_masks_to_image | |||
from modelscope.utils.test_utils import test_level | |||
class ImagePanopticSegmentationTest(unittest.TestCase): | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_image_panoptic_segmentation(self): | |||
input_location = 'data/test/images/image_panoptic_segmentation.jpg' | |||
model_id = 'damo/cv_swinL_panoptic-segmentation_cocopan' | |||
pan_segmentor = pipeline(Tasks.image_segmentation, model=model_id) | |||
result = pan_segmentor(input_location) | |||
draw_img = panoptic_seg_masks_to_image(result[OutputKeys.MASKS]) | |||
cv2.imwrite('result.jpg', draw_img) | |||
print('print test_image_panoptic_segmentation return success') | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_image_panoptic_segmentation_from_PIL(self): | |||
input_location = 'data/test/images/image_panoptic_segmentation.jpg' | |||
model_id = 'damo/cv_swinL_panoptic-segmentation_cocopan' | |||
pan_segmentor = pipeline(Tasks.image_segmentation, model=model_id) | |||
PIL_array = PIL.Image.open(input_location) | |||
result = pan_segmentor(PIL_array) | |||
draw_img = panoptic_seg_masks_to_image(result[OutputKeys.MASKS]) | |||
cv2.imwrite('result.jpg', draw_img) | |||
print('print test_image_panoptic_segmentation from PIL return success') | |||
if __name__ == '__main__': | |||
unittest.main() |