@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8e9ab135da7eacabdeeeee11ba4b7bcdd1bfac128cf92a9de9c79f984060ae1e
size 259865
@@ -11,6 +11,7 @@ class Models(object):
"""
# vision models
csrnet = 'csrnet'
cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin'
# nlp models
bert = 'bert'
@@ -67,6 +68,7 @@ class Pipelines(object):
image_super_resolution = 'rrdb-image-super-resolution'
face_image_generation = 'gan-face-image-generation'
style_transfer = 'AAMS-style-transfer'
image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
# nlp tasks
sentence_similarity = 'sentence-similarity'
@@ -124,6 +126,7 @@ class Preprocessors(object):
# cv preprocessor
load_image = 'load-image'
image_color_enhance_preprocessor = 'image-color-enhance-preprocessor'
image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor'
# nlp preprocessor
sen_sim_tokenizer = 'sen-sim-tokenizer'
@@ -157,6 +160,8 @@ class Metrics(object):
# accuracy
accuracy = 'accuracy'
# metric for the image instance segmentation task
image_ins_seg_coco_metric = 'image-ins-seg-coco-metric'
# metrics for sequence classification task
seq_cls_metric = 'seq_cls_metric'
# metrics for token-classification task
@@ -1,5 +1,7 @@
from .base import Metric
from .builder import METRICS, build_metric, task_default_metrics
from .image_color_enhance_metric import ImageColorEnhanceMetric
from .image_instance_segmentation_metric import \
    ImageInstanceSegmentationCOCOMetric
from .sequence_classification_metric import SequenceClassificationMetric
from .text_generation_metric import TextGenerationMetric
@@ -18,6 +18,7 @@ class MetricKeys(object):
task_default_metrics = {
Tasks.image_segmentation: [Metrics.image_ins_seg_coco_metric],
Tasks.sentence_similarity: [Metrics.seq_cls_metric],
Tasks.sentiment_classification: [Metrics.seq_cls_metric],
Tasks.text_generation: [Metrics.text_gen_metric],
@@ -0,0 +1,312 @@ | |||||
import os.path as osp | |||||
import tempfile | |||||
from collections import OrderedDict | |||||
from typing import Any, Dict | |||||
import numpy as np | |||||
import pycocotools.mask as mask_util | |||||
from pycocotools.coco import COCO | |||||
from pycocotools.cocoeval import COCOeval | |||||
from modelscope.fileio import dump, load | |||||
from modelscope.metainfo import Metrics | |||||
from modelscope.metrics import METRICS, Metric | |||||
from modelscope.utils.registry import default_group | |||||
@METRICS.register_module( | |||||
group_key=default_group, module_name=Metrics.image_ins_seg_coco_metric) | |||||
class ImageInstanceSegmentationCOCOMetric(Metric): | |||||
"""The metric computation class for COCO-style image instance segmentation. | |||||
""" | |||||
def __init__(self): | |||||
self.ann_file = None | |||||
self.classes = None | |||||
self.metrics = ['bbox', 'segm'] | |||||
self.proposal_nums = (100, 300, 1000) | |||||
self.iou_thrs = np.linspace( | |||||
.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True) | |||||
self.results = [] | |||||
def add(self, outputs: Dict[str, Any], inputs: Dict[str, Any]): | |||||
result = outputs['eval_result'] | |||||
# encode mask results | |||||
if isinstance(result[0], tuple): | |||||
result = [(bbox_results, encode_mask_results(mask_results)) | |||||
for bbox_results, mask_results in result] | |||||
self.results.extend(result) | |||||
if self.ann_file is None: | |||||
self.ann_file = outputs['img_metas'][0]['ann_file'] | |||||
self.classes = outputs['img_metas'][0]['classes'] | |||||
def evaluate(self): | |||||
cocoGt = COCO(self.ann_file) | |||||
self.cat_ids = cocoGt.getCatIds(catNms=self.classes) | |||||
self.img_ids = cocoGt.getImgIds() | |||||
result_files, tmp_dir = self.format_results(self.results, self.img_ids) | |||||
eval_results = OrderedDict() | |||||
for metric in self.metrics: | |||||
iou_type = metric | |||||
if metric not in result_files: | |||||
raise KeyError(f'{metric} is not in results') | |||||
try: | |||||
predictions = load(result_files[metric]) | |||||
if iou_type == 'segm': | |||||
# Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa | |||||
# When evaluating mask AP, if the results contain bbox, | |||||
# cocoapi will use the box area instead of the mask area | |||||
# for calculating the instance area. Though the overall AP | |||||
# is not affected, this leads to different | |||||
# small/medium/large mask AP results. | |||||
for x in predictions: | |||||
x.pop('bbox') | |||||
cocoDt = cocoGt.loadRes(predictions) | |||||
except IndexError: | |||||
print('The testing results of the whole dataset are empty.')
break | |||||
cocoEval = COCOeval(cocoGt, cocoDt, iou_type) | |||||
cocoEval.params.catIds = self.cat_ids | |||||
cocoEval.params.imgIds = self.img_ids | |||||
cocoEval.params.maxDets = list(self.proposal_nums) | |||||
cocoEval.params.iouThrs = self.iou_thrs | |||||
# mapping of cocoEval.stats | |||||
coco_metric_names = { | |||||
'mAP': 0, | |||||
'mAP_50': 1, | |||||
'mAP_75': 2, | |||||
'mAP_s': 3, | |||||
'mAP_m': 4, | |||||
'mAP_l': 5, | |||||
'AR@100': 6, | |||||
'AR@300': 7, | |||||
'AR@1000': 8, | |||||
'AR_s@1000': 9, | |||||
'AR_m@1000': 10, | |||||
'AR_l@1000': 11 | |||||
} | |||||
cocoEval.evaluate() | |||||
cocoEval.accumulate() | |||||
cocoEval.summarize() | |||||
metric_items = [ | |||||
'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l' | |||||
] | |||||
for metric_item in metric_items: | |||||
key = f'{metric}_{metric_item}' | |||||
val = float( | |||||
f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}') | |||||
eval_results[key] = val | |||||
ap = cocoEval.stats[:6] | |||||
eval_results[f'{metric}_mAP_copypaste'] = ( | |||||
f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} ' | |||||
f'{ap[4]:.3f} {ap[5]:.3f}') | |||||
if tmp_dir is not None: | |||||
tmp_dir.cleanup() | |||||
return eval_results | |||||
def format_results(self, results, img_ids, jsonfile_prefix=None, **kwargs): | |||||
"""Format the results to json (standard format for COCO evaluation). | |||||
Args: | |||||
results (list[tuple | numpy.ndarray]): Testing results of the | |||||
dataset. | |||||
img_ids (list[int]): Image ids of the dataset.
jsonfile_prefix (str | None): The prefix of json files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.
Returns:
tuple: (result_files, tmp_dir), result_files is a dict containing \
the json filepaths, tmp_dir is the temporary directory created \
for saving json files when jsonfile_prefix is not specified.
""" | |||||
assert isinstance(results, list), 'results must be a list' | |||||
assert len(results) == len(img_ids), ( | |||||
'The length of results is not equal to the dataset len: {} != {}'. | |||||
format(len(results), len(img_ids))) | |||||
if jsonfile_prefix is None: | |||||
tmp_dir = tempfile.TemporaryDirectory() | |||||
jsonfile_prefix = osp.join(tmp_dir.name, 'results') | |||||
else: | |||||
tmp_dir = None | |||||
result_files = self.results2json(results, jsonfile_prefix) | |||||
return result_files, tmp_dir | |||||
def xyxy2xywh(self, bbox): | |||||
"""Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO | |||||
evaluation. | |||||
Args: | |||||
bbox (numpy.ndarray): The bounding boxes, shape (4, ), in | |||||
``xyxy`` order. | |||||
Returns: | |||||
list[float]: The converted bounding boxes, in ``xywh`` order. | |||||
""" | |||||
_bbox = bbox.tolist() | |||||
return [ | |||||
_bbox[0], | |||||
_bbox[1], | |||||
_bbox[2] - _bbox[0], | |||||
_bbox[3] - _bbox[1], | |||||
] | |||||
def _proposal2json(self, results): | |||||
"""Convert proposal results to COCO json style.""" | |||||
json_results = [] | |||||
for idx in range(len(self.img_ids)): | |||||
img_id = self.img_ids[idx] | |||||
bboxes = results[idx] | |||||
for i in range(bboxes.shape[0]): | |||||
data = dict() | |||||
data['image_id'] = img_id | |||||
data['bbox'] = self.xyxy2xywh(bboxes[i]) | |||||
data['score'] = float(bboxes[i][4]) | |||||
data['category_id'] = 1 | |||||
json_results.append(data) | |||||
return json_results | |||||
def _det2json(self, results): | |||||
"""Convert detection results to COCO json style.""" | |||||
json_results = [] | |||||
for idx in range(len(self.img_ids)): | |||||
img_id = self.img_ids[idx] | |||||
result = results[idx] | |||||
for label in range(len(result)): | |||||
# Skip invalid predicted labels here, as a fixed num_classes of 80 (COCO) is used
# (assuming the input dataset has no more than 80 classes).
# In practice, it is recommended to manually set
# `num_classes=${your test dataset class num}` in configuration.json.
if label >= len(self.classes): | |||||
break | |||||
bboxes = result[label] | |||||
for i in range(bboxes.shape[0]): | |||||
data = dict() | |||||
data['image_id'] = img_id | |||||
data['bbox'] = self.xyxy2xywh(bboxes[i]) | |||||
data['score'] = float(bboxes[i][4]) | |||||
data['category_id'] = self.cat_ids[label] | |||||
json_results.append(data) | |||||
return json_results | |||||
def _segm2json(self, results): | |||||
"""Convert instance segmentation results to COCO json style.""" | |||||
bbox_json_results = [] | |||||
segm_json_results = [] | |||||
for idx in range(len(self.img_ids)): | |||||
img_id = self.img_ids[idx] | |||||
det, seg = results[idx] | |||||
for label in range(len(det)): | |||||
# Skip invalid predicted labels here, as a fixed num_classes of 80 (COCO) is used
# (assuming the input dataset has no more than 80 classes).
# In practice, it is recommended to manually set
# `num_classes=${your test dataset class num}` in configuration.json.
if label >= len(self.classes): | |||||
break | |||||
# bbox results | |||||
bboxes = det[label] | |||||
for i in range(bboxes.shape[0]): | |||||
data = dict() | |||||
data['image_id'] = img_id | |||||
data['bbox'] = self.xyxy2xywh(bboxes[i]) | |||||
data['score'] = float(bboxes[i][4]) | |||||
data['category_id'] = self.cat_ids[label] | |||||
bbox_json_results.append(data) | |||||
# segm results | |||||
# some detectors use different scores for bbox and mask | |||||
if isinstance(seg, tuple): | |||||
segms = seg[0][label] | |||||
mask_score = seg[1][label] | |||||
else: | |||||
segms = seg[label] | |||||
mask_score = [bbox[4] for bbox in bboxes] | |||||
for i in range(bboxes.shape[0]): | |||||
data = dict() | |||||
data['image_id'] = img_id | |||||
data['bbox'] = self.xyxy2xywh(bboxes[i]) | |||||
data['score'] = float(mask_score[i]) | |||||
data['category_id'] = self.cat_ids[label] | |||||
if isinstance(segms[i]['counts'], bytes): | |||||
segms[i]['counts'] = segms[i]['counts'].decode() | |||||
data['segmentation'] = segms[i] | |||||
segm_json_results.append(data) | |||||
return bbox_json_results, segm_json_results | |||||
def results2json(self, results, outfile_prefix): | |||||
"""Dump the detection results to a COCO style json file. | |||||
There are 3 types of results: proposals, bbox predictions, mask | |||||
predictions, and they have different data types. This method will | |||||
automatically recognize the type, and dump them to json files. | |||||
Args: | |||||
results (list[list | tuple | ndarray]): Testing results of the | |||||
dataset. | |||||
outfile_prefix (str): The filename prefix of the json files. If the | |||||
prefix is "somepath/xxx", the json files will be named | |||||
"somepath/xxx.bbox.json", "somepath/xxx.segm.json", | |||||
"somepath/xxx.proposal.json". | |||||
Returns: | |||||
dict[str: str]: Possible keys are "bbox", "segm", "proposal", and \ | |||||
values are corresponding filenames. | |||||
""" | |||||
result_files = dict() | |||||
if isinstance(results[0], list): | |||||
json_results = self._det2json(results) | |||||
result_files['bbox'] = f'{outfile_prefix}.bbox.json' | |||||
result_files['proposal'] = f'{outfile_prefix}.bbox.json' | |||||
dump(json_results, result_files['bbox']) | |||||
elif isinstance(results[0], tuple): | |||||
json_results = self._segm2json(results) | |||||
result_files['bbox'] = f'{outfile_prefix}.bbox.json' | |||||
result_files['proposal'] = f'{outfile_prefix}.bbox.json' | |||||
result_files['segm'] = f'{outfile_prefix}.segm.json' | |||||
dump(json_results[0], result_files['bbox']) | |||||
dump(json_results[1], result_files['segm']) | |||||
elif isinstance(results[0], np.ndarray): | |||||
json_results = self._proposal2json(results) | |||||
result_files['proposal'] = f'{outfile_prefix}.proposal.json' | |||||
dump(json_results, result_files['proposal']) | |||||
else: | |||||
raise TypeError('invalid type of results') | |||||
return result_files | |||||
def encode_mask_results(mask_results): | |||||
"""Encode bitmap mask to RLE code. | |||||
Args: | |||||
mask_results (list | tuple[list]): bitmap mask results. | |||||
In mask scoring rcnn, mask_results is a tuple of (segm_results, | |||||
segm_cls_score). | |||||
Returns: | |||||
list | tuple: RLE encoded mask. | |||||
""" | |||||
if isinstance(mask_results, tuple): # mask scoring | |||||
cls_segms, cls_mask_scores = mask_results | |||||
else: | |||||
cls_segms = mask_results | |||||
num_classes = len(cls_segms) | |||||
encoded_mask_results = [[] for _ in range(num_classes)] | |||||
for i in range(len(cls_segms)): | |||||
for cls_segm in cls_segms[i]: | |||||
encoded_mask_results[i].append( | |||||
mask_util.encode( | |||||
np.array( | |||||
cls_segm[:, :, np.newaxis], order='F', | |||||
dtype='uint8'))[0]) # encoded with RLE | |||||
if isinstance(mask_results, tuple): | |||||
return encoded_mask_results, cls_mask_scores | |||||
else: | |||||
return encoded_mask_results |
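# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# A minimal, self-contained example of the helpers above on made-up data:
# converting an ``xyxy`` box to COCO ``xywh`` and RLE-encoding a toy binary mask.
if __name__ == '__main__':
    _metric = ImageInstanceSegmentationCOCOMetric()
    _box = np.array([10., 20., 110., 70.])  # [x1, y1, x2, y2]
    print(_metric.xyxy2xywh(_box))  # -> [10.0, 20.0, 100.0, 50.0]

    # one class with one instance: a 4x4 mask holding a 2x2 foreground square
    _mask = np.zeros((4, 4), dtype=np.uint8)
    _mask[1:3, 1:3] = 1
    _encoded = encode_mask_results([[_mask]])
    print(_encoded[0][0])  # RLE dict with 'size' and 'counts'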
@@ -0,0 +1,2 @@ | |||||
from .cascade_mask_rcnn_swin import CascadeMaskRCNNSwin | |||||
from .model import CascadeMaskRCNNSwinModel |
@@ -0,0 +1 @@ | |||||
from .swin_transformer import SwinTransformer |
@@ -0,0 +1,694 @@ | |||||
# Modified from: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
import torch.utils.checkpoint as checkpoint | |||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_ | |||||
class Mlp(nn.Module): | |||||
""" Multilayer perceptron.""" | |||||
def __init__(self, | |||||
in_features, | |||||
hidden_features=None, | |||||
out_features=None, | |||||
act_layer=nn.GELU, | |||||
drop=0.): | |||||
super().__init__() | |||||
out_features = out_features or in_features | |||||
hidden_features = hidden_features or in_features | |||||
self.fc1 = nn.Linear(in_features, hidden_features) | |||||
self.act = act_layer() | |||||
self.fc2 = nn.Linear(hidden_features, out_features) | |||||
self.drop = nn.Dropout(drop) | |||||
def forward(self, x): | |||||
x = self.fc1(x) | |||||
x = self.act(x) | |||||
x = self.drop(x) | |||||
x = self.fc2(x) | |||||
x = self.drop(x) | |||||
return x | |||||
def window_partition(x, window_size): | |||||
""" | |||||
Args: | |||||
x: (B, H, W, C) | |||||
window_size (int): window size | |||||
Returns: | |||||
windows: (num_windows*B, window_size, window_size, C) | |||||
""" | |||||
B, H, W, C = x.shape | |||||
x = x.view(B, H // window_size, window_size, W // window_size, window_size, | |||||
C) | |||||
windows = x.permute(0, 1, 3, 2, 4, | |||||
5).contiguous().view(-1, window_size, window_size, C) | |||||
return windows | |||||
def window_reverse(windows, window_size, H, W): | |||||
""" | |||||
Args: | |||||
windows: (num_windows*B, window_size, window_size, C) | |||||
window_size (int): Window size | |||||
H (int): Height of image | |||||
W (int): Width of image | |||||
Returns: | |||||
x: (B, H, W, C) | |||||
""" | |||||
B = int(windows.shape[0] / (H * W / window_size / window_size)) | |||||
x = windows.view(B, H // window_size, W // window_size, window_size, | |||||
window_size, -1) | |||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) | |||||
return x | |||||
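# --- Illustrative example (added, not part of the original file) ---
# window_partition and window_reverse are inverses of each other when H and W
# are multiples of the window size; the shapes below use made-up values.
if __name__ == '__main__':
    _x = torch.randn(2, 14, 14, 96)           # (B, H, W, C)
    _wins = window_partition(_x, 7)           # (2*2*2, 7, 7, 96) = (8, 7, 7, 96)
    _back = window_reverse(_wins, 7, 14, 14)  # (2, 14, 14, 96)
    assert torch.equal(_x, _back)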
class WindowAttention(nn.Module): | |||||
""" Window based multi-head self attention (W-MSA) module with relative position bias. | |||||
It supports both of shifted and non-shifted window. | |||||
Args: | |||||
dim (int): Number of input channels. | |||||
window_size (tuple[int]): The height and width of the window. | |||||
num_heads (int): Number of attention heads. | |||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set | |||||
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 | |||||
proj_drop (float, optional): Dropout ratio of output. Default: 0.0 | |||||
""" | |||||
def __init__(self, | |||||
dim, | |||||
window_size, | |||||
num_heads, | |||||
qkv_bias=True, | |||||
qk_scale=None, | |||||
attn_drop=0., | |||||
proj_drop=0.): | |||||
super().__init__() | |||||
self.dim = dim | |||||
self.window_size = window_size # Wh, Ww | |||||
self.num_heads = num_heads | |||||
head_dim = dim // num_heads | |||||
self.scale = qk_scale or head_dim**-0.5 | |||||
# define a parameter table of relative position bias | |||||
self.relative_position_bias_table = nn.Parameter( | |||||
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), | |||||
num_heads)) # 2*Wh-1 * 2*Ww-1, nH | |||||
# get pair-wise relative position index for each token inside the window | |||||
coords_h = torch.arange(self.window_size[0]) | |||||
coords_w = torch.arange(self.window_size[1]) | |||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww | |||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww | |||||
relative_coords = coords_flatten[:, :, | |||||
None] - coords_flatten[:, | |||||
None, :] # 2, Wh*Ww, Wh*Ww | |||||
relative_coords = relative_coords.permute( | |||||
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 | |||||
relative_coords[:, :, | |||||
0] += self.window_size[0] - 1 # shift to start from 0 | |||||
relative_coords[:, :, 1] += self.window_size[1] - 1 | |||||
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 | |||||
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww | |||||
self.register_buffer('relative_position_index', | |||||
relative_position_index) | |||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | |||||
self.attn_drop = nn.Dropout(attn_drop) | |||||
self.proj = nn.Linear(dim, dim) | |||||
self.proj_drop = nn.Dropout(proj_drop) | |||||
trunc_normal_(self.relative_position_bias_table, std=.02) | |||||
self.softmax = nn.Softmax(dim=-1) | |||||
def forward(self, x, mask=None): | |||||
""" Forward function. | |||||
Args: | |||||
x: input features with shape of (num_windows*B, N, C) | |||||
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None | |||||
""" | |||||
B_, N, C = x.shape | |||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, | |||||
C // self.num_heads).permute(2, 0, 3, 1, 4) | |||||
q, k, v = qkv[0], qkv[1], qkv[ | |||||
2] # make torchscript happy (cannot use tensor as tuple) | |||||
q = q * self.scale | |||||
attn = (q @ k.transpose(-2, -1)) | |||||
relative_position_bias = self.relative_position_bias_table[ | |||||
self.relative_position_index.view(-1)].view( | |||||
self.window_size[0] * self.window_size[1], | |||||
self.window_size[0] * self.window_size[1], | |||||
-1) # Wh*Ww,Wh*Ww,nH | |||||
relative_position_bias = relative_position_bias.permute( | |||||
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww | |||||
attn = attn + relative_position_bias.unsqueeze(0) | |||||
if mask is not None: | |||||
nW = mask.shape[0] | |||||
attn = attn.view(B_ // nW, nW, self.num_heads, N, | |||||
N) + mask.unsqueeze(1).unsqueeze(0) | |||||
attn = attn.view(-1, self.num_heads, N, N) | |||||
attn = self.softmax(attn) | |||||
else: | |||||
attn = self.softmax(attn) | |||||
attn = self.attn_drop(attn) | |||||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C) | |||||
x = self.proj(x) | |||||
x = self.proj_drop(x) | |||||
return x | |||||
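# --- Illustrative example (added, not part of the original file) ---
# Running the window attention module on dummy windows: with a 7x7 window and
# C=96, each window holds N=49 tokens and the output keeps the input shape.
if __name__ == '__main__':
    _attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
    _windows = torch.randn(8, 49, 96)   # (num_windows*B, N, C)
    print(_attn(_windows).shape)        # torch.Size([8, 49, 96])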
class SwinTransformerBlock(nn.Module): | |||||
""" Swin Transformer Block. | |||||
Args: | |||||
dim (int): Number of input channels. | |||||
num_heads (int): Number of attention heads. | |||||
window_size (int): Window size. | |||||
shift_size (int): Shift size for SW-MSA. | |||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. | |||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. | |||||
drop (float, optional): Dropout rate. Default: 0.0 | |||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0 | |||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0 | |||||
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU | |||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||||
""" | |||||
def __init__(self, | |||||
dim, | |||||
num_heads, | |||||
window_size=7, | |||||
shift_size=0, | |||||
mlp_ratio=4., | |||||
qkv_bias=True, | |||||
qk_scale=None, | |||||
drop=0., | |||||
attn_drop=0., | |||||
drop_path=0., | |||||
act_layer=nn.GELU, | |||||
norm_layer=nn.LayerNorm): | |||||
super().__init__() | |||||
self.dim = dim | |||||
self.num_heads = num_heads | |||||
self.window_size = window_size | |||||
self.shift_size = shift_size | |||||
self.mlp_ratio = mlp_ratio | |||||
assert 0 <= self.shift_size < self.window_size, 'shift_size must be in [0, window_size)'
self.norm1 = norm_layer(dim) | |||||
self.attn = WindowAttention( | |||||
dim, | |||||
window_size=to_2tuple(self.window_size), | |||||
num_heads=num_heads, | |||||
qkv_bias=qkv_bias, | |||||
qk_scale=qk_scale, | |||||
attn_drop=attn_drop, | |||||
proj_drop=drop) | |||||
self.drop_path = DropPath( | |||||
drop_path) if drop_path > 0. else nn.Identity() | |||||
self.norm2 = norm_layer(dim) | |||||
mlp_hidden_dim = int(dim * mlp_ratio) | |||||
self.mlp = Mlp( | |||||
in_features=dim, | |||||
hidden_features=mlp_hidden_dim, | |||||
act_layer=act_layer, | |||||
drop=drop) | |||||
self.H = None | |||||
self.W = None | |||||
def forward(self, x, mask_matrix): | |||||
""" Forward function. | |||||
Args: | |||||
x: Input feature, tensor size (B, H*W, C). | |||||
H, W: Spatial resolution of the input feature. | |||||
mask_matrix: Attention mask for cyclic shift. | |||||
""" | |||||
B, L, C = x.shape | |||||
H, W = self.H, self.W | |||||
assert L == H * W, 'input feature has wrong size' | |||||
shortcut = x | |||||
x = self.norm1(x) | |||||
x = x.view(B, H, W, C) | |||||
# pad feature maps to multiples of window size | |||||
pad_l = pad_t = 0 | |||||
pad_r = (self.window_size - W % self.window_size) % self.window_size | |||||
pad_b = (self.window_size - H % self.window_size) % self.window_size | |||||
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) | |||||
_, Hp, Wp, _ = x.shape | |||||
# cyclic shift | |||||
if self.shift_size > 0: | |||||
shifted_x = torch.roll( | |||||
x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) | |||||
attn_mask = mask_matrix | |||||
else: | |||||
shifted_x = x | |||||
attn_mask = None | |||||
# partition windows | |||||
x_windows = window_partition( | |||||
shifted_x, self.window_size) # nW*B, window_size, window_size, C | |||||
x_windows = x_windows.view(-1, self.window_size * self.window_size, | |||||
C) # nW*B, window_size*window_size, C | |||||
# W-MSA/SW-MSA | |||||
attn_windows = self.attn( | |||||
x_windows, mask=attn_mask) # nW*B, window_size*window_size, C | |||||
# merge windows | |||||
attn_windows = attn_windows.view(-1, self.window_size, | |||||
self.window_size, C) | |||||
shifted_x = window_reverse(attn_windows, self.window_size, Hp, | |||||
Wp) # B H' W' C | |||||
# reverse cyclic shift | |||||
if self.shift_size > 0: | |||||
x = torch.roll( | |||||
shifted_x, | |||||
shifts=(self.shift_size, self.shift_size), | |||||
dims=(1, 2)) | |||||
else: | |||||
x = shifted_x | |||||
if pad_r > 0 or pad_b > 0: | |||||
x = x[:, :H, :W, :].contiguous() | |||||
x = x.view(B, H * W, C) | |||||
# FFN | |||||
x = shortcut + self.drop_path(x) | |||||
x = x + self.drop_path(self.mlp(self.norm2(x))) | |||||
return x | |||||
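# --- Illustrative example (added, not part of the original file) ---
# A single non-shifted block applied to a flattened 14x14 feature map. H and W
# must be assigned on the block before calling forward; no attention mask is
# needed because shift_size is 0.
if __name__ == '__main__':
    _blk = SwinTransformerBlock(dim=96, num_heads=3, window_size=7, shift_size=0)
    _blk.H, _blk.W = 14, 14
    _feat = torch.randn(2, 14 * 14, 96)          # (B, H*W, C)
    print(_blk(_feat, mask_matrix=None).shape)   # torch.Size([2, 196, 96])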
class PatchMerging(nn.Module): | |||||
""" Patch Merging Layer | |||||
Args: | |||||
dim (int): Number of input channels. | |||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||||
""" | |||||
def __init__(self, dim, norm_layer=nn.LayerNorm): | |||||
super().__init__() | |||||
self.dim = dim | |||||
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) | |||||
self.norm = norm_layer(4 * dim) | |||||
def forward(self, x, H, W): | |||||
""" Forward function. | |||||
Args: | |||||
x: Input feature, tensor size (B, H*W, C). | |||||
H, W: Spatial resolution of the input feature. | |||||
""" | |||||
B, L, C = x.shape | |||||
assert L == H * W, 'input feature has wrong size' | |||||
x = x.view(B, H, W, C) | |||||
# padding | |||||
pad_input = (H % 2 == 1) or (W % 2 == 1) | |||||
if pad_input: | |||||
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) | |||||
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C | |||||
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C | |||||
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C | |||||
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C | |||||
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C | |||||
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C | |||||
x = self.norm(x) | |||||
x = self.reduction(x) | |||||
return x | |||||
class BasicLayer(nn.Module): | |||||
""" A basic Swin Transformer layer for one stage. | |||||
Args: | |||||
dim (int): Number of feature channels | |||||
depth (int): Number of blocks in this stage.
num_heads (int): Number of attention heads.
window_size (int): Local window size. Default: 7. | |||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. | |||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. | |||||
drop (float, optional): Dropout rate. Default: 0.0 | |||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0 | |||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 | |||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None | |||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. | |||||
""" | |||||
def __init__(self, | |||||
dim, | |||||
depth, | |||||
num_heads, | |||||
window_size=7, | |||||
mlp_ratio=4., | |||||
qkv_bias=True, | |||||
qk_scale=None, | |||||
drop=0., | |||||
attn_drop=0., | |||||
drop_path=0., | |||||
norm_layer=nn.LayerNorm, | |||||
downsample=None, | |||||
use_checkpoint=False): | |||||
super().__init__() | |||||
self.window_size = window_size | |||||
self.shift_size = window_size // 2 | |||||
self.depth = depth | |||||
self.use_checkpoint = use_checkpoint | |||||
# build blocks | |||||
self.blocks = nn.ModuleList([ | |||||
SwinTransformerBlock( | |||||
dim=dim, | |||||
num_heads=num_heads, | |||||
window_size=window_size, | |||||
shift_size=0 if (i % 2 == 0) else window_size // 2, | |||||
mlp_ratio=mlp_ratio, | |||||
qkv_bias=qkv_bias, | |||||
qk_scale=qk_scale, | |||||
drop=drop, | |||||
attn_drop=attn_drop, | |||||
drop_path=drop_path[i] | |||||
if isinstance(drop_path, list) else drop_path, | |||||
norm_layer=norm_layer) for i in range(depth) | |||||
]) | |||||
# patch merging layer | |||||
if downsample is not None: | |||||
self.downsample = downsample(dim=dim, norm_layer=norm_layer) | |||||
else: | |||||
self.downsample = None | |||||
def forward(self, x, H, W): | |||||
""" Forward function. | |||||
Args: | |||||
x: Input feature, tensor size (B, H*W, C). | |||||
H, W: Spatial resolution of the input feature. | |||||
""" | |||||
# calculate attention mask for SW-MSA | |||||
Hp = int(np.ceil(H / self.window_size)) * self.window_size | |||||
Wp = int(np.ceil(W / self.window_size)) * self.window_size | |||||
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 | |||||
h_slices = (slice(0, -self.window_size), | |||||
slice(-self.window_size, | |||||
-self.shift_size), slice(-self.shift_size, None)) | |||||
w_slices = (slice(0, -self.window_size), | |||||
slice(-self.window_size, | |||||
-self.shift_size), slice(-self.shift_size, None)) | |||||
cnt = 0 | |||||
for h in h_slices: | |||||
for w in w_slices: | |||||
img_mask[:, h, w, :] = cnt | |||||
cnt += 1 | |||||
mask_windows = window_partition( | |||||
img_mask, self.window_size) # nW, window_size, window_size, 1 | |||||
mask_windows = mask_windows.view(-1, | |||||
self.window_size * self.window_size) | |||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) | |||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, | |||||
float(-100.0)).masked_fill( | |||||
attn_mask == 0, float(0.0)) | |||||
for blk in self.blocks: | |||||
blk.H, blk.W = H, W | |||||
if self.use_checkpoint: | |||||
x = checkpoint.checkpoint(blk, x, attn_mask) | |||||
else: | |||||
x = blk(x, attn_mask) | |||||
if self.downsample is not None: | |||||
x_down = self.downsample(x, H, W) | |||||
Wh, Ww = (H + 1) // 2, (W + 1) // 2 | |||||
return x, H, W, x_down, Wh, Ww | |||||
else: | |||||
return x, H, W, x, H, W | |||||
class PatchEmbed(nn.Module): | |||||
""" Image to Patch Embedding | |||||
Args: | |||||
patch_size (int): Patch token size. Default: 4. | |||||
in_chans (int): Number of input image channels. Default: 3. | |||||
embed_dim (int): Number of linear projection output channels. Default: 96. | |||||
norm_layer (nn.Module, optional): Normalization layer. Default: None | |||||
""" | |||||
def __init__(self, | |||||
patch_size=4, | |||||
in_chans=3, | |||||
embed_dim=96, | |||||
norm_layer=None): | |||||
super().__init__() | |||||
patch_size = to_2tuple(patch_size) | |||||
self.patch_size = patch_size | |||||
self.in_chans = in_chans | |||||
self.embed_dim = embed_dim | |||||
self.proj = nn.Conv2d( | |||||
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) | |||||
if norm_layer is not None: | |||||
self.norm = norm_layer(embed_dim) | |||||
else: | |||||
self.norm = None | |||||
def forward(self, x): | |||||
"""Forward function.""" | |||||
# padding | |||||
_, _, H, W = x.size() | |||||
if W % self.patch_size[1] != 0: | |||||
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) | |||||
if H % self.patch_size[0] != 0: | |||||
x = F.pad(x, | |||||
(0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) | |||||
x = self.proj(x) # B C Wh Ww | |||||
if self.norm is not None: | |||||
Wh, Ww = x.size(2), x.size(3) | |||||
x = x.flatten(2).transpose(1, 2) | |||||
x = self.norm(x) | |||||
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) | |||||
return x | |||||
class SwinTransformer(nn.Module): | |||||
""" Swin Transformer backbone. | |||||
A PyTorch implementation of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Inspiration from | |||||
https://github.com/SwinTransformer/Swin-Transformer-Object-Detection | |||||
Args: | |||||
pretrain_img_size (int): Input image size for training the pretrained model,
used in absolute position embedding. Default 224.
patch_size (int | tuple(int)): Patch size. Default: 4. | |||||
in_chans (int): Number of input image channels. Default: 3. | |||||
embed_dim (int): Number of linear projection output channels. Default: 96. | |||||
depths (tuple[int]): Depths of each Swin Transformer stage. | |||||
num_heads (tuple[int]): Number of attention heads of each stage.
window_size (int): Window size. Default: 7. | |||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. | |||||
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True | |||||
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. | |||||
drop_rate (float): Dropout rate. | |||||
attn_drop_rate (float): Attention dropout rate. Default: 0. | |||||
drop_path_rate (float): Stochastic depth rate. Default: 0.2. | |||||
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. | |||||
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. | |||||
patch_norm (bool): If True, add normalization after patch embedding. Default: True. | |||||
out_indices (Sequence[int]): Output from which stages. | |||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode). | |||||
-1 means not freezing any parameters. | |||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. | |||||
""" | |||||
def __init__(self, | |||||
pretrain_img_size=224, | |||||
patch_size=4, | |||||
in_chans=3, | |||||
embed_dim=96, | |||||
depths=[2, 2, 6, 2], | |||||
num_heads=[3, 6, 12, 24], | |||||
window_size=7, | |||||
mlp_ratio=4., | |||||
qkv_bias=True, | |||||
qk_scale=None, | |||||
drop_rate=0., | |||||
attn_drop_rate=0., | |||||
drop_path_rate=0.2, | |||||
norm_layer=nn.LayerNorm, | |||||
ape=False, | |||||
patch_norm=True, | |||||
out_indices=(0, 1, 2, 3), | |||||
frozen_stages=-1, | |||||
use_checkpoint=False): | |||||
super().__init__() | |||||
self.pretrain_img_size = pretrain_img_size | |||||
self.num_layers = len(depths) | |||||
self.embed_dim = embed_dim | |||||
self.ape = ape | |||||
self.patch_norm = patch_norm | |||||
self.out_indices = out_indices | |||||
self.frozen_stages = frozen_stages | |||||
# split image into non-overlapping patches | |||||
self.patch_embed = PatchEmbed( | |||||
patch_size=patch_size, | |||||
in_chans=in_chans, | |||||
embed_dim=embed_dim, | |||||
norm_layer=norm_layer if self.patch_norm else None) | |||||
# absolute position embedding | |||||
if self.ape: | |||||
pretrain_img_size = to_2tuple(pretrain_img_size) | |||||
patch_size = to_2tuple(patch_size) | |||||
patches_resolution = [ | |||||
pretrain_img_size[0] // patch_size[0], | |||||
pretrain_img_size[1] // patch_size[1] | |||||
] | |||||
self.absolute_pos_embed = nn.Parameter( | |||||
torch.zeros(1, embed_dim, patches_resolution[0], | |||||
patches_resolution[1])) | |||||
trunc_normal_(self.absolute_pos_embed, std=.02) | |||||
self.pos_drop = nn.Dropout(p=drop_rate) | |||||
# stochastic depth | |||||
dpr = [ | |||||
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) | |||||
] # stochastic depth decay rule | |||||
# build layers | |||||
self.layers = nn.ModuleList() | |||||
for i_layer in range(self.num_layers): | |||||
layer = BasicLayer( | |||||
dim=int(embed_dim * 2**i_layer), | |||||
depth=depths[i_layer], | |||||
num_heads=num_heads[i_layer], | |||||
window_size=window_size, | |||||
mlp_ratio=mlp_ratio, | |||||
qkv_bias=qkv_bias, | |||||
qk_scale=qk_scale, | |||||
drop=drop_rate, | |||||
attn_drop=attn_drop_rate, | |||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], | |||||
norm_layer=norm_layer, | |||||
downsample=PatchMerging if | |||||
(i_layer < self.num_layers - 1) else None, | |||||
use_checkpoint=use_checkpoint) | |||||
self.layers.append(layer) | |||||
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] | |||||
self.num_features = num_features | |||||
# add a norm layer for each output | |||||
for i_layer in out_indices: | |||||
layer = norm_layer(num_features[i_layer]) | |||||
layer_name = f'norm{i_layer}' | |||||
self.add_module(layer_name, layer) | |||||
self._freeze_stages() | |||||
def _freeze_stages(self): | |||||
if self.frozen_stages >= 0: | |||||
self.patch_embed.eval() | |||||
for param in self.patch_embed.parameters(): | |||||
param.requires_grad = False | |||||
if self.frozen_stages >= 1 and self.ape: | |||||
self.absolute_pos_embed.requires_grad = False | |||||
if self.frozen_stages >= 2: | |||||
self.pos_drop.eval() | |||||
for i in range(0, self.frozen_stages - 1): | |||||
m = self.layers[i] | |||||
m.eval() | |||||
for param in m.parameters(): | |||||
param.requires_grad = False | |||||
def init_weights(self): | |||||
"""Initialize the weights in backbone.""" | |||||
def _init_weights(m): | |||||
if isinstance(m, nn.Linear): | |||||
trunc_normal_(m.weight, std=.02) | |||||
if isinstance(m, nn.Linear) and m.bias is not None: | |||||
nn.init.constant_(m.bias, 0) | |||||
elif isinstance(m, nn.LayerNorm): | |||||
nn.init.constant_(m.bias, 0) | |||||
nn.init.constant_(m.weight, 1.0) | |||||
self.apply(_init_weights) | |||||
def forward(self, x): | |||||
"""Forward function.""" | |||||
x = self.patch_embed(x) | |||||
Wh, Ww = x.size(2), x.size(3) | |||||
if self.ape: | |||||
# interpolate the position embedding to the corresponding size | |||||
absolute_pos_embed = F.interpolate( | |||||
self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') | |||||
x = (x + absolute_pos_embed).flatten(2).transpose(1, | |||||
2) # B Wh*Ww C | |||||
else: | |||||
x = x.flatten(2).transpose(1, 2) | |||||
x = self.pos_drop(x) | |||||
outs = [] | |||||
for i in range(self.num_layers): | |||||
layer = self.layers[i] | |||||
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) | |||||
if i in self.out_indices: | |||||
norm_layer = getattr(self, f'norm{i}') | |||||
x_out = norm_layer(x_out) | |||||
out = x_out.view(-1, H, W, | |||||
self.num_features[i]).permute(0, 3, 1, | |||||
2).contiguous() | |||||
outs.append(out) | |||||
return tuple(outs) | |||||
def train(self, mode=True): | |||||
"""Convert the model into training mode while keep layers freezed.""" | |||||
super(SwinTransformer, self).train(mode) | |||||
self._freeze_stages() |
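# --- Illustrative usage sketch (added, not part of the original file) ---
# Building a Swin-T sized backbone and running a dummy image through it. The
# backbone returns one feature map per stage in ``out_indices``, with strides
# 4, 8, 16 and 32 relative to the input. All values below are sample settings.
if __name__ == '__main__':
    _backbone = SwinTransformer(
        embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7)
    _backbone.init_weights()
    _img = torch.randn(1, 3, 224, 224)
    for _i, _feat in enumerate(_backbone(_img)):
        print(_i, tuple(_feat.shape))
    # 0 (1, 96, 56, 56)
    # 1 (1, 192, 28, 28)
    # 2 (1, 384, 14, 14)
    # 3 (1, 768, 7, 7)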
@@ -0,0 +1,266 @@ | |||||
import os | |||||
from collections import OrderedDict | |||||
import torch | |||||
import torch.distributed as dist | |||||
import torch.nn as nn | |||||
from modelscope.models.cv.image_instance_segmentation.backbones import \ | |||||
SwinTransformer | |||||
from modelscope.utils.constant import ModelFile | |||||
from modelscope.utils.logger import get_logger | |||||
logger = get_logger() | |||||
def build_backbone(cfg): | |||||
assert isinstance(cfg, dict) | |||||
cfg = cfg.copy() | |||||
type = cfg.pop('type') | |||||
if type == 'SwinTransformer': | |||||
return SwinTransformer(**cfg) | |||||
else: | |||||
raise ValueError(f'backbone \'{type}\' is not supported.') | |||||
def build_neck(cfg): | |||||
assert isinstance(cfg, dict) | |||||
cfg = cfg.copy() | |||||
type = cfg.pop('type') | |||||
if type == 'FPN': | |||||
from mmdet.models import FPN | |||||
return FPN(**cfg) | |||||
else: | |||||
raise ValueError(f'neck \'{type}\' is not supported.') | |||||
def build_rpn_head(cfg): | |||||
assert isinstance(cfg, dict) | |||||
cfg = cfg.copy() | |||||
type = cfg.pop('type') | |||||
if type == 'RPNHead': | |||||
from mmdet.models import RPNHead | |||||
return RPNHead(**cfg) | |||||
else: | |||||
raise ValueError(f'rpn head \'{type}\' is not supported.') | |||||
def build_roi_head(cfg): | |||||
assert isinstance(cfg, dict) | |||||
cfg = cfg.copy() | |||||
type = cfg.pop('type') | |||||
if type == 'CascadeRoIHead': | |||||
from mmdet.models import CascadeRoIHead | |||||
return CascadeRoIHead(**cfg) | |||||
else: | |||||
raise ValueError(f'roi head \'{type}\' is not supported.') | |||||
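# --- Illustrative config sketch (added, not part of the original file) ---
# The builders above expect plain dicts whose 'type' selects the module and whose
# remaining keys are forwarded as keyword arguments. The values below are sample
# Swin-T sized settings, not the exact configuration shipped with the model.
if __name__ == '__main__':
    _backbone_cfg = dict(
        type='SwinTransformer',
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        out_indices=(0, 1, 2, 3))
    _neck_cfg = dict(
        type='FPN', in_channels=[96, 192, 384, 768], out_channels=256, num_outs=5)
    print(build_backbone(_backbone_cfg))
    print(build_neck(_neck_cfg))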
class CascadeMaskRCNNSwin(nn.Module): | |||||
def __init__(self, | |||||
backbone, | |||||
neck, | |||||
rpn_head, | |||||
roi_head, | |||||
pretrained=None, | |||||
**kwargs): | |||||
""" | |||||
Args: | |||||
backbone (dict): backbone config. | |||||
neck (dict): neck config. | |||||
rpn_head (dict): rpn_head config. | |||||
roi_head (dict): roi_head config. | |||||
pretrained (bool): whether to use pretrained model | |||||
""" | |||||
super(CascadeMaskRCNNSwin, self).__init__() | |||||
self.backbone = build_backbone(backbone) | |||||
self.neck = build_neck(neck) | |||||
self.rpn_head = build_rpn_head(rpn_head) | |||||
self.roi_head = build_roi_head(roi_head) | |||||
self.classes = kwargs.pop('classes', None) | |||||
if pretrained: | |||||
assert 'model_dir' in kwargs, 'pretrained model dir is missing.' | |||||
model_path = os.path.join(kwargs['model_dir'], | |||||
ModelFile.TORCH_MODEL_FILE) | |||||
logger.info(f'loading model from {model_path}') | |||||
weight = torch.load(model_path)['state_dict'] | |||||
tgt_weight = self.state_dict() | |||||
for name in list(weight.keys()): | |||||
if name in tgt_weight: | |||||
load_size = weight[name].size() | |||||
tgt_size = tgt_weight[name].size() | |||||
mis_match = False | |||||
if len(load_size) != len(tgt_size): | |||||
mis_match = True | |||||
else: | |||||
for n1, n2 in zip(load_size, tgt_size): | |||||
if n1 != n2: | |||||
mis_match = True | |||||
break | |||||
if mis_match: | |||||
logger.info(f'size mismatch for {name}, skip loading.') | |||||
del weight[name] | |||||
self.load_state_dict(weight, strict=False) | |||||
logger.info('load model done') | |||||
from mmcv.parallel import DataContainer, scatter | |||||
self.data_container = DataContainer | |||||
self.scatter = scatter | |||||
def extract_feat(self, img): | |||||
x = self.backbone(img) | |||||
x = self.neck(x) | |||||
return x | |||||
def forward_train(self, | |||||
img, | |||||
img_metas, | |||||
gt_bboxes, | |||||
gt_labels, | |||||
gt_bboxes_ignore=None, | |||||
gt_masks=None, | |||||
proposals=None, | |||||
**kwargs): | |||||
""" | |||||
Args: | |||||
img (Tensor): of shape (N, C, H, W) encoding input images. | |||||
Typically these should be mean centered and std scaled. | |||||
img_metas (list[dict]): list of image info dict where each dict | |||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain | |||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |||||
For details on the values of these keys see | |||||
`mmdet/datasets/pipelines/formatting.py:Collect`. | |||||
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with | |||||
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. | |||||
gt_labels (list[Tensor]): class indices corresponding to each box | |||||
gt_bboxes_ignore (None | list[Tensor]): specify which bounding | |||||
boxes can be ignored when computing the loss. | |||||
gt_masks (None | Tensor) : true segmentation masks for each box | |||||
used if the architecture supports a segmentation task. | |||||
proposals : override rpn proposals with custom proposals. Use when | |||||
`with_rpn` is False. | |||||
Returns: | |||||
dict[str, Tensor]: a dictionary of loss components | |||||
""" | |||||
x = self.extract_feat(img) | |||||
losses = dict() | |||||
# RPN forward and loss | |||||
proposal_cfg = self.rpn_head.train_cfg.get('rpn_proposal', | |||||
self.rpn_head.test_cfg) | |||||
rpn_losses, proposal_list = self.rpn_head.forward_train( | |||||
x, | |||||
img_metas, | |||||
gt_bboxes, | |||||
gt_labels=None, | |||||
gt_bboxes_ignore=gt_bboxes_ignore, | |||||
proposal_cfg=proposal_cfg, | |||||
**kwargs) | |||||
losses.update(rpn_losses) | |||||
roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list, | |||||
gt_bboxes, gt_labels, | |||||
gt_bboxes_ignore, gt_masks, | |||||
**kwargs) | |||||
losses.update(roi_losses) | |||||
return losses | |||||
def forward_test(self, img, img_metas, proposals=None, rescale=True): | |||||
x = self.extract_feat(img) | |||||
if proposals is None: | |||||
proposal_list = self.rpn_head.simple_test_rpn(x, img_metas) | |||||
else: | |||||
proposal_list = proposals | |||||
result = self.roi_head.simple_test( | |||||
x, proposal_list, img_metas, rescale=rescale) | |||||
return dict(eval_result=result, img_metas=img_metas) | |||||
def forward(self, img, img_metas, **kwargs): | |||||
# currently only CPU or a single GPU is supported
if isinstance(img, self.data_container): | |||||
img = img.data[0] | |||||
if isinstance(img_metas, self.data_container): | |||||
img_metas = img_metas.data[0] | |||||
for k, w in kwargs.items(): | |||||
if isinstance(w, self.data_container): | |||||
w = w.data[0] | |||||
kwargs[k] = w | |||||
if next(self.parameters()).is_cuda: | |||||
device = next(self.parameters()).device | |||||
img = self.scatter(img, [device])[0] | |||||
img_metas = self.scatter(img_metas, [device])[0] | |||||
for k, w in kwargs.items(): | |||||
kwargs[k] = self.scatter(w, [device])[0] | |||||
if self.training: | |||||
losses = self.forward_train(img, img_metas, **kwargs) | |||||
loss, log_vars = self._parse_losses(losses) | |||||
outputs = dict( | |||||
loss=loss, log_vars=log_vars, num_samples=len(img_metas)) | |||||
return outputs | |||||
else: | |||||
return self.forward_test(img, img_metas, **kwargs) | |||||
def _parse_losses(self, losses): | |||||
log_vars = OrderedDict() | |||||
for loss_name, loss_value in losses.items(): | |||||
if isinstance(loss_value, torch.Tensor): | |||||
log_vars[loss_name] = loss_value.mean() | |||||
elif isinstance(loss_value, list): | |||||
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) | |||||
else: | |||||
raise TypeError( | |||||
f'{loss_name} is not a tensor or list of tensors') | |||||
loss = sum(_value for _key, _value in log_vars.items() | |||||
if 'loss' in _key) | |||||
log_vars['loss'] = loss | |||||
for loss_name, loss_value in log_vars.items(): | |||||
# reduce loss when distributed training | |||||
if dist.is_available() and dist.is_initialized(): | |||||
loss_value = loss_value.data.clone() | |||||
dist.all_reduce(loss_value.div_(dist.get_world_size())) | |||||
log_vars[loss_name] = loss_value.item() | |||||
return loss, log_vars | |||||
def train_step(self, data, optimizer): | |||||
losses = self(**data) | |||||
loss, log_vars = self._parse_losses(losses) | |||||
outputs = dict( | |||||
loss=loss, log_vars=log_vars, num_samples=len(data['img_metas'])) | |||||
return outputs | |||||
def val_step(self, data, optimizer=None): | |||||
losses = self(**data) | |||||
loss, log_vars = self._parse_losses(losses) | |||||
outputs = dict( | |||||
loss=loss, log_vars=log_vars, num_samples=len(data['img_metas'])) | |||||
return outputs |
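# --- Illustrative sketch (added, not part of the original file) ---
# _parse_losses expects a dict mapping loss names to tensors (or lists of
# tensors); every entry whose key contains 'loss' is averaged and summed into
# the total. The toy dict below mimics that convention outside the model.
if __name__ == '__main__':
    _losses = {
        'loss_rpn_cls': torch.tensor([0.3, 0.5]),
        'loss_rpn_bbox': torch.tensor(0.2),
        'acc': torch.tensor(0.9),  # logged, but not part of the total loss
    }
    _log_vars = {k: v.mean() for k, v in _losses.items()}
    _total = sum(v for k, v in _log_vars.items() if 'loss' in k)
    print(float(_total))  # 0.4 + 0.2 = 0.6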
@@ -0,0 +1,2 @@ | |||||
from .dataset import ImageInstanceSegmentationCocoDataset | |||||
from .transforms import build_preprocess_transform |
@@ -0,0 +1,332 @@ | |||||
import os.path as osp | |||||
import numpy as np | |||||
from pycocotools.coco import COCO | |||||
from torch.utils.data import Dataset | |||||
class ImageInstanceSegmentationCocoDataset(Dataset): | |||||
"""Coco-style dataset for image instance segmentation. | |||||
Args: | |||||
ann_file (str): Annotation file path. | |||||
classes (Sequence[str], optional): Specify classes to load.
If None, ``cls.CLASSES`` will be used. Default: None.
data_root (str, optional): Data root for ``ann_file``, | |||||
``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified. | |||||
test_mode (bool, optional): If set True, annotation will not be loaded. | |||||
filter_empty_gt (bool, optional): If set True, images without bounding
boxes of the dataset's classes will be filtered out. This option
only works when `test_mode=False`, i.e., we never filter images
during tests.
""" | |||||
CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', | |||||
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', | |||||
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', | |||||
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', | |||||
'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', | |||||
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', | |||||
'baseball glove', 'skateboard', 'surfboard', 'tennis racket', | |||||
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', | |||||
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', | |||||
'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', | |||||
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', | |||||
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', | |||||
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', | |||||
'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush') | |||||
def __init__(self, | |||||
ann_file, | |||||
classes=None, | |||||
data_root=None, | |||||
img_prefix='', | |||||
seg_prefix=None, | |||||
test_mode=False, | |||||
filter_empty_gt=True): | |||||
self.ann_file = ann_file | |||||
self.data_root = data_root | |||||
self.img_prefix = img_prefix | |||||
self.seg_prefix = seg_prefix | |||||
self.test_mode = test_mode | |||||
self.filter_empty_gt = filter_empty_gt | |||||
self.CLASSES = self.get_classes(classes) | |||||
# join paths if data_root is specified | |||||
if self.data_root is not None: | |||||
if not osp.isabs(self.ann_file): | |||||
self.ann_file = osp.join(self.data_root, self.ann_file) | |||||
if not (self.img_prefix is None or osp.isabs(self.img_prefix)): | |||||
self.img_prefix = osp.join(self.data_root, self.img_prefix) | |||||
if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)): | |||||
self.seg_prefix = osp.join(self.data_root, self.seg_prefix) | |||||
# load annotations | |||||
self.data_infos = self.load_annotations(self.ann_file) | |||||
# filter out images that are too small or contain no annotations
if not test_mode: | |||||
valid_inds = self._filter_imgs() | |||||
self.data_infos = [self.data_infos[i] for i in valid_inds] | |||||
# set group flag for the sampler | |||||
self._set_group_flag() | |||||
self.preprocessor = None | |||||
def __len__(self): | |||||
"""Total number of samples of data.""" | |||||
return len(self.data_infos) | |||||
def load_annotations(self, ann_file): | |||||
"""Load annotation from COCO style annotation file. | |||||
Args: | |||||
ann_file (str): Path of annotation file. | |||||
Returns: | |||||
list[dict]: Annotation info from COCO api. | |||||
""" | |||||
self.coco = COCO(ann_file) | |||||
# The order of returned `cat_ids` will not | |||||
# change with the order of the CLASSES | |||||
self.cat_ids = self.coco.getCatIds(catNms=self.CLASSES) | |||||
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} | |||||
self.img_ids = self.coco.getImgIds() | |||||
data_infos = [] | |||||
total_ann_ids = [] | |||||
for i in self.img_ids: | |||||
info = self.coco.loadImgs([i])[0] | |||||
info['filename'] = info['file_name'] | |||||
info['ann_file'] = ann_file | |||||
info['classes'] = self.CLASSES | |||||
data_infos.append(info) | |||||
ann_ids = self.coco.getAnnIds(imgIds=[i]) | |||||
total_ann_ids.extend(ann_ids) | |||||
assert len(set(total_ann_ids)) == len( | |||||
total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!" | |||||
return data_infos | |||||
def get_ann_info(self, idx): | |||||
"""Get COCO annotation by index. | |||||
Args: | |||||
idx (int): Index of data. | |||||
Returns: | |||||
dict: Annotation info of specified index. | |||||
""" | |||||
img_id = self.data_infos[idx]['id'] | |||||
ann_ids = self.coco.getAnnIds(imgIds=[img_id]) | |||||
ann_info = self.coco.loadAnns(ann_ids) | |||||
return self._parse_ann_info(self.data_infos[idx], ann_info) | |||||
def get_cat_ids(self, idx): | |||||
"""Get COCO category ids by index. | |||||
Args: | |||||
idx (int): Index of data. | |||||
Returns: | |||||
list[int]: All categories in the image of specified index. | |||||
""" | |||||
img_id = self.data_infos[idx]['id'] | |||||
ann_ids = self.coco.getAnnIds(imgIds=[img_id]) | |||||
ann_info = self.coco.loadAnns(ann_ids) | |||||
return [ann['category_id'] for ann in ann_info] | |||||
def pre_pipeline(self, results): | |||||
"""Prepare results dict for pipeline.""" | |||||
results['img_prefix'] = self.img_prefix | |||||
results['seg_prefix'] = self.seg_prefix | |||||
results['bbox_fields'] = [] | |||||
results['mask_fields'] = [] | |||||
results['seg_fields'] = [] | |||||
def _filter_imgs(self, min_size=32): | |||||
"""Filter images too small or without ground truths.""" | |||||
valid_inds = [] | |||||
# obtain images that contain annotation | |||||
ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values()) | |||||
# obtain images that contain annotations of the required categories | |||||
ids_in_cat = set() | |||||
for i, class_id in enumerate(self.cat_ids): | |||||
ids_in_cat |= set(self.coco.catToImgs[class_id]) | |||||
# merge the image id sets of the two conditions and use the merged set | |||||
# to filter out images if self.filter_empty_gt=True | |||||
ids_in_cat &= ids_with_ann | |||||
valid_img_ids = [] | |||||
for i, img_info in enumerate(self.data_infos): | |||||
img_id = self.img_ids[i] | |||||
if self.filter_empty_gt and img_id not in ids_in_cat: | |||||
continue | |||||
if min(img_info['width'], img_info['height']) >= min_size: | |||||
valid_inds.append(i) | |||||
valid_img_ids.append(img_id) | |||||
self.img_ids = valid_img_ids | |||||
return valid_inds | |||||
def _parse_ann_info(self, img_info, ann_info): | |||||
"""Parse bbox and mask annotation. | |||||
Args: | |||||
ann_info (list[dict]): Annotation info of an image. | |||||
Returns: | |||||
dict: A dict containing the following keys: bboxes, bboxes_ignore,\ | |||||
labels, masks, seg_map. "masks" are raw annotations and not \ | |||||
decoded into binary masks. | |||||
""" | |||||
gt_bboxes = [] | |||||
gt_labels = [] | |||||
gt_bboxes_ignore = [] | |||||
gt_masks_ann = [] | |||||
for i, ann in enumerate(ann_info): | |||||
if ann.get('ignore', False): | |||||
continue | |||||
x1, y1, w, h = ann['bbox'] | |||||
inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0)) | |||||
inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0)) | |||||
if inter_w * inter_h == 0: | |||||
continue | |||||
if ann['area'] <= 0 or w < 1 or h < 1: | |||||
continue | |||||
if ann['category_id'] not in self.cat_ids: | |||||
continue | |||||
bbox = [x1, y1, x1 + w, y1 + h] | |||||
if ann.get('iscrowd', False): | |||||
gt_bboxes_ignore.append(bbox) | |||||
else: | |||||
gt_bboxes.append(bbox) | |||||
gt_labels.append(self.cat2label[ann['category_id']]) | |||||
gt_masks_ann.append(ann.get('segmentation', None)) | |||||
if gt_bboxes: | |||||
gt_bboxes = np.array(gt_bboxes, dtype=np.float32) | |||||
gt_labels = np.array(gt_labels, dtype=np.int64) | |||||
else: | |||||
gt_bboxes = np.zeros((0, 4), dtype=np.float32) | |||||
gt_labels = np.array([], dtype=np.int64) | |||||
if gt_bboxes_ignore: | |||||
gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32) | |||||
else: | |||||
gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32) | |||||
seg_map = img_info['filename'].replace('jpg', 'png') | |||||
ann = dict( | |||||
bboxes=gt_bboxes, | |||||
labels=gt_labels, | |||||
bboxes_ignore=gt_bboxes_ignore, | |||||
masks=gt_masks_ann, | |||||
seg_map=seg_map) | |||||
return ann | |||||
def _set_group_flag(self): | |||||
"""Set flag according to image aspect ratio. | |||||
Images with aspect ratio greater than 1 will be set as group 1, | |||||
otherwise group 0. | |||||
""" | |||||
self.flag = np.zeros(len(self), dtype=np.uint8) | |||||
for i in range(len(self)): | |||||
img_info = self.data_infos[i] | |||||
if img_info['width'] / img_info['height'] > 1: | |||||
self.flag[i] = 1 | |||||
def _rand_another(self, idx): | |||||
"""Get another random index from the same group as the given index.""" | |||||
pool = np.where(self.flag == self.flag[idx])[0] | |||||
return np.random.choice(pool) | |||||
def __getitem__(self, idx): | |||||
"""Get training/test data after pipeline. | |||||
Args: | |||||
idx (int): Index of data. | |||||
Returns: | |||||
dict: Training/test data (with annotation if `test_mode` is set \ | |||||
True). | |||||
""" | |||||
if self.test_mode: | |||||
return self.prepare_test_img(idx) | |||||
while True: | |||||
data = self.prepare_train_img(idx) | |||||
if data is None: | |||||
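                # the pipeline may return None (e.g., all gt boxes filtered out); resample another image from the same aspect-ratio group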
idx = self._rand_another(idx) | |||||
continue | |||||
return data | |||||
def prepare_train_img(self, idx): | |||||
"""Get training data and annotations after pipeline. | |||||
Args: | |||||
idx (int): Index of data. | |||||
Returns: | |||||
dict: Training data and annotation after pipeline with new keys \ | |||||
introduced by pipeline. | |||||
""" | |||||
img_info = self.data_infos[idx] | |||||
ann_info = self.get_ann_info(idx) | |||||
results = dict(img_info=img_info, ann_info=ann_info) | |||||
self.pre_pipeline(results) | |||||
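        # without an attached preprocessor, return the raw, un-transformed sample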
if self.preprocessor is None: | |||||
return results | |||||
self.preprocessor.train() | |||||
return self.preprocessor(results) | |||||
def prepare_test_img(self, idx): | |||||
"""Get testing data after pipeline. | |||||
Args: | |||||
idx (int): Index of data. | |||||
Returns: | |||||
dict: Testing data after pipeline with new keys introduced by \ | |||||
pipeline. | |||||
""" | |||||
img_info = self.data_infos[idx] | |||||
results = dict(img_info=img_info) | |||||
self.pre_pipeline(results) | |||||
if self.preprocessor is None: | |||||
return results | |||||
self.preprocessor.eval() | |||||
results = self.preprocessor(results) | |||||
return results | |||||
@classmethod | |||||
def get_classes(cls, classes=None): | |||||
"""Get class names of current dataset. | |||||
Args: | |||||
classes (Sequence[str] | None): If classes is None, use | |||||
default CLASSES defined by builtin dataset. If classes is | |||||
a tuple or list, override the CLASSES defined by the dataset. | |||||
Returns: | |||||
tuple[str] or list[str]: Names of categories of the dataset. | |||||
""" | |||||
if classes is None: | |||||
return cls.CLASSES | |||||
if isinstance(classes, (tuple, list)): | |||||
class_names = classes | |||||
else: | |||||
raise ValueError(f'Unsupported type {type(classes)} of classes.') | |||||
return class_names | |||||
def to_torch_dataset(self, preprocessors=None): | |||||
self.preprocessor = preprocessors | |||||
return self |
@@ -0,0 +1,109 @@ | |||||
import os.path as osp | |||||
import numpy as np | |||||
from modelscope.fileio import File | |||||
def build_preprocess_transform(cfg): | |||||
assert isinstance(cfg, dict) | |||||
cfg = cfg.copy() | |||||
type = cfg.pop('type') | |||||
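    # mmdet transforms are imported lazily so mmdet is only required when they are actually used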
if type == 'LoadImageFromFile': | |||||
return LoadImageFromFile(**cfg) | |||||
elif type == 'LoadAnnotations': | |||||
from mmdet.datasets.pipelines import LoadAnnotations | |||||
return LoadAnnotations(**cfg) | |||||
elif type == 'Resize': | |||||
if 'img_scale' in cfg: | |||||
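            # mmdet expects img_scale as tuple(s), while config files store lists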
if isinstance(cfg.img_scale[0], list): | |||||
elems = [] | |||||
for elem in cfg.img_scale: | |||||
elems.append(tuple(elem)) | |||||
cfg.img_scale = elems | |||||
else: | |||||
cfg.img_scale = tuple(cfg.img_scale) | |||||
from mmdet.datasets.pipelines import Resize | |||||
return Resize(**cfg) | |||||
elif type == 'RandomFlip': | |||||
from mmdet.datasets.pipelines import RandomFlip | |||||
return RandomFlip(**cfg) | |||||
elif type == 'Normalize': | |||||
from mmdet.datasets.pipelines import Normalize | |||||
return Normalize(**cfg) | |||||
elif type == 'Pad': | |||||
from mmdet.datasets.pipelines import Pad | |||||
return Pad(**cfg) | |||||
elif type == 'DefaultFormatBundle': | |||||
from mmdet.datasets.pipelines import DefaultFormatBundle | |||||
return DefaultFormatBundle(**cfg) | |||||
elif type == 'ImageToTensor': | |||||
from mmdet.datasets.pipelines import ImageToTensor | |||||
return ImageToTensor(**cfg) | |||||
elif type == 'Collect': | |||||
from mmdet.datasets.pipelines import Collect | |||||
return Collect(**cfg) | |||||
else: | |||||
raise ValueError(f'preprocess transform \'{type}\' is not supported.') | |||||
class LoadImageFromFile: | |||||
"""Load an image from file. | |||||
Required keys are "img_prefix" and "img_info" (a dict that must contain the | |||||
key "filename"). Added or updated keys are "filename", "img", "img_shape", | |||||
"ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), | |||||
"scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). | |||||
Args: | |||||
to_float32 (bool): Whether to convert the loaded image to a float32 | |||||
numpy array. If set to False, the loaded image is an uint8 array. | |||||
Defaults to False. | |||||
""" | |||||
def __init__(self, to_float32=False, mode='rgb'): | |||||
self.to_float32 = to_float32 | |||||
self.mode = mode | |||||
from mmcv import imfrombytes | |||||
self.imfrombytes = imfrombytes | |||||
def __call__(self, results): | |||||
"""Call functions to load image and get image meta information. | |||||
Args: | |||||
results (dict): Result dict from :obj:`ImageInstanceSegmentationDataset`. | |||||
Returns: | |||||
dict: The dict contains loaded image and meta information. | |||||
""" | |||||
if results['img_prefix'] is not None: | |||||
filename = osp.join(results['img_prefix'], | |||||
results['img_info']['filename']) | |||||
else: | |||||
filename = results['img_info']['filename'] | |||||
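        # read raw bytes through modelscope's File abstraction, then decode with mmcv (pillow backend)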
img_bytes = File.read(filename) | |||||
img = self.imfrombytes(img_bytes, 'color', 'bgr', backend='pillow') | |||||
if self.to_float32: | |||||
img = img.astype(np.float32) | |||||
results['filename'] = filename | |||||
results['ori_filename'] = results['img_info']['filename'] | |||||
results['img'] = img | |||||
results['img_shape'] = img.shape | |||||
results['ori_shape'] = img.shape | |||||
results['img_fields'] = ['img'] | |||||
results['ann_file'] = results['img_info']['ann_file'] | |||||
results['classes'] = results['img_info']['classes'] | |||||
return results | |||||
def __repr__(self): | |||||
repr_str = (f'{self.__class__.__name__}(' | |||||
f'to_float32={self.to_float32}, ' | |||||
f"mode='{self.mode}'") | |||||
return repr_str |
@@ -0,0 +1,49 @@ | |||||
import os | |||||
from typing import Any, Dict | |||||
import torch | |||||
from modelscope.metainfo import Models | |||||
from modelscope.models.base import TorchModel | |||||
from modelscope.models.builder import MODELS | |||||
from modelscope.models.cv.image_instance_segmentation import \ | |||||
CascadeMaskRCNNSwin | |||||
from modelscope.utils.config import Config | |||||
from modelscope.utils.constant import ModelFile, Tasks | |||||
@MODELS.register_module( | |||||
Tasks.image_segmentation, module_name=Models.cascade_mask_rcnn_swin) | |||||
class CascadeMaskRCNNSwinModel(TorchModel): | |||||
def __init__(self, model_dir=None, *args, **kwargs): | |||||
""" | |||||
Args: | |||||
model_dir (str): model directory. | |||||
""" | |||||
super(CascadeMaskRCNNSwinModel, self).__init__( | |||||
model_dir=model_dir, *args, **kwargs) | |||||
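        # when no explicit architecture kwargs are given, read the model config shipped with the model directory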
if 'backbone' not in kwargs: | |||||
config_path = os.path.join(model_dir, ModelFile.CONFIGURATION) | |||||
cfg = Config.from_file(config_path) | |||||
model_cfg = cfg.model | |||||
kwargs.update(model_cfg) | |||||
self.model = CascadeMaskRCNNSwin(model_dir=model_dir, **kwargs) | |||||
self.device = torch.device( | |||||
'cuda' if torch.cuda.is_available() else 'cpu') | |||||
self.model.to(self.device) | |||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
output = self.model(**input) | |||||
return output | |||||
def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: | |||||
return input | |||||
def compute_loss(self, outputs: Dict[str, Any], labels): | |||||
pass |
@@ -0,0 +1,203 @@ | |||||
import itertools | |||||
import cv2 | |||||
import numpy as np | |||||
import pycocotools.mask as maskUtils | |||||
import torch | |||||
from modelscope.outputs import OutputKeys | |||||
def get_seg_bboxes(bboxes, labels, segms=None, class_names=None, score_thr=0.): | |||||
assert bboxes.ndim == 2, \ | |||||
f' bboxes ndim should be 2, but its ndim is {bboxes.ndim}.' | |||||
assert labels.ndim == 1, \ | |||||
f' labels ndim should be 1, but its ndim is {labels.ndim}.' | |||||
    assert bboxes.shape[0] == labels.shape[0], \
        'bboxes and labels should have the same length.'
    assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5, \
        f'bboxes.shape[1] should be 4 or 5, but it is {bboxes.shape[1]}.'
if score_thr > 0: | |||||
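        # thresholding requires boxes with a score column, i.e. shape (N, 5)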
assert bboxes.shape[1] == 5 | |||||
scores = bboxes[:, -1] | |||||
inds = scores > score_thr | |||||
bboxes = bboxes[inds, :] | |||||
labels = labels[inds] | |||||
if segms is not None: | |||||
segms = segms[inds, ...] | |||||
bboxes_names = [] | |||||
for i, (bbox, label) in enumerate(zip(bboxes, labels)): | |||||
label_name = class_names[ | |||||
label] if class_names is not None else f'class {label}' | |||||
bbox = [0 if b < 0 else b for b in list(bbox)] | |||||
bbox.append(label_name) | |||||
bbox.append(segms[i].astype(bool)) | |||||
bboxes_names.append(bbox) | |||||
return bboxes_names | |||||
def get_img_seg_results(det_rawdata=None, | |||||
class_names=None, | |||||
score_thr=0.3, | |||||
is_decode=True): | |||||
    '''
    Get all detected boxes (with masks) of one image.
    score_thr: classification probability threshold; detections scoring below it are dropped.
    Output format: [[x1, y1, x2, y2, prob, cls_name, mask], ...]
    '''
    assert det_rawdata is not None, 'det_rawdata should not be None.'
    assert class_names is not None, 'class_names should not be None.'
if isinstance(det_rawdata, tuple): | |||||
bbox_result, segm_result = det_rawdata | |||||
if isinstance(segm_result, tuple): | |||||
segm_result = segm_result[0] # ms rcnn | |||||
else: | |||||
bbox_result, segm_result = det_rawdata, None | |||||
bboxes = np.vstack(bbox_result) | |||||
labels = [ | |||||
np.full(bbox.shape[0], i, dtype=np.int32) | |||||
for i, bbox in enumerate(bbox_result) | |||||
] | |||||
labels = np.concatenate(labels) | |||||
segms = None | |||||
if segm_result is not None and len(labels) > 0: # non empty | |||||
segms = list(itertools.chain(*segm_result)) | |||||
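        # decode RLE-encoded masks into a (num_instances, H, W) binary array when requested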
if is_decode: | |||||
segms = maskUtils.decode(segms) | |||||
segms = segms.transpose(2, 0, 1) | |||||
if isinstance(segms[0], torch.Tensor): | |||||
segms = torch.stack(segms, dim=0).detach().cpu().numpy() | |||||
else: | |||||
segms = np.stack(segms, axis=0) | |||||
bboxes_names = get_seg_bboxes( | |||||
bboxes, | |||||
labels, | |||||
segms=segms, | |||||
class_names=class_names, | |||||
score_thr=score_thr) | |||||
return bboxes_names | |||||
def get_img_ins_seg_result(img_seg_result=None, | |||||
class_names=None, | |||||
score_thr=0.3): | |||||
    assert img_seg_result is not None, 'img_seg_result should not be None.'
    assert class_names is not None, 'class_names should not be None.'
img_seg_result = get_img_seg_results( | |||||
det_rawdata=(img_seg_result[0], img_seg_result[1]), | |||||
class_names=class_names, | |||||
score_thr=score_thr, | |||||
is_decode=False) | |||||
results_dict = { | |||||
OutputKeys.BOXES: [], | |||||
OutputKeys.MASKS: [], | |||||
OutputKeys.LABELS: [], | |||||
OutputKeys.SCORES: [] | |||||
} | |||||
for seg_result in img_seg_result: | |||||
box = { | |||||
            'x': int(seg_result[0]),
            'y': int(seg_result[1]),
            'w': int(seg_result[2] - seg_result[0]),
            'h': int(seg_result[3] - seg_result[1])
        }
        score = float(seg_result[4])
        category = seg_result[5]
        mask = np.array(seg_result[6], order='F', dtype='uint8')
        mask = mask.astype(float)
results_dict[OutputKeys.BOXES].append(box) | |||||
results_dict[OutputKeys.MASKS].append(mask) | |||||
results_dict[OutputKeys.SCORES].append(score) | |||||
results_dict[OutputKeys.LABELS].append(category) | |||||
return results_dict | |||||
def show_result( | |||||
img, | |||||
result, | |||||
out_file='result.jpg', | |||||
show_box=True, | |||||
show_label=True, | |||||
show_score=True, | |||||
alpha=0.5, | |||||
fontScale=0.5, | |||||
fontFace=cv2.FONT_HERSHEY_COMPLEX_SMALL, | |||||
thickness=1, | |||||
): | |||||
assert isinstance(img, (str, np.ndarray)), \ | |||||
f'img must be str or np.ndarray, but got {type(img)}.' | |||||
if isinstance(img, str): | |||||
img = cv2.imread(img) | |||||
if len(img.shape) == 2: | |||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | |||||
img = img.astype(np.float32) | |||||
labels = result[OutputKeys.LABELS] | |||||
scores = result[OutputKeys.SCORES] | |||||
boxes = result[OutputKeys.BOXES] | |||||
masks = result[OutputKeys.MASKS] | |||||
for label, score, box, mask in zip(labels, scores, boxes, masks): | |||||
random_color = np.array([ | |||||
np.random.random() * 255.0, | |||||
np.random.random() * 255.0, | |||||
np.random.random() * 255.0 | |||||
]) | |||||
x1 = int(box['x']) | |||||
y1 = int(box['y']) | |||||
w = int(box['w']) | |||||
h = int(box['h']) | |||||
x2 = x1 + w | |||||
y2 = y1 + h | |||||
if show_box: | |||||
cv2.rectangle( | |||||
img, (x1, y1), (x2, y2), random_color, thickness=thickness) | |||||
if show_label or show_score: | |||||
if show_label and show_score: | |||||
text = '{}|{}'.format(label, round(float(score), 2)) | |||||
elif show_label: | |||||
text = '{}'.format(label) | |||||
else: | |||||
text = '{}'.format(round(float(score), 2)) | |||||
retval, baseLine = cv2.getTextSize( | |||||
text, | |||||
fontFace=fontFace, | |||||
fontScale=fontScale, | |||||
thickness=thickness) | |||||
cv2.rectangle( | |||||
img, (x1, y1 - retval[1] - baseLine), (x1 + retval[0], y1), | |||||
thickness=-1, | |||||
color=(0, 0, 0)) | |||||
cv2.putText( | |||||
img, | |||||
text, (x1, y1 - baseLine), | |||||
fontScale=fontScale, | |||||
fontFace=fontFace, | |||||
thickness=thickness, | |||||
color=random_color) | |||||
idx = np.nonzero(mask) | |||||
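        # alpha-blend the instance mask region with the per-instance random color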
img[idx[0], idx[1], :] *= 1.0 - alpha | |||||
img[idx[0], idx[1], :] += alpha * random_color | |||||
cv2.imwrite(out_file, img) |
@@ -13,6 +13,7 @@ class OutputKeys(object):
    POSES = 'poses'
    CAPTION = 'caption'
    BOXES = 'boxes'
    MASKS = 'masks'
    TEXT = 'text'
    POLYGONS = 'polygons'
    OUTPUT = 'output'
@@ -76,6 +76,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                          'damo/cv_daflow_virtual-tryon_base'),
    Tasks.image_colorization: (Pipelines.image_colorization,
                               'damo/cv_unet_image-colorization'),
    Tasks.image_segmentation:
    (Pipelines.image_instance_segmentation,
     'damo/cv_swin-b_image-instance-segmentation_coco'),
    Tasks.style_transfer: (Pipelines.style_transfer,
                           'damo/cv_aams_style-transfer_damo'),
    Tasks.face_image_generation: (Pipelines.face_image_generation,
@@ -11,6 +11,7 @@ try:
    from .image_colorization_pipeline import ImageColorizationPipeline
    from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
    from .face_image_generation_pipeline import FaceImageGenerationPipeline
    from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline
except ModuleNotFoundError as e:
    if str(e) == "No module named 'torch'":
        pass
@@ -0,0 +1,105 @@ | |||||
import os | |||||
from typing import Any, Dict, Optional, Union | |||||
import cv2 | |||||
import numpy as np | |||||
import torch | |||||
from PIL import Image | |||||
from modelscope.metainfo import Pipelines | |||||
from modelscope.models.cv.image_instance_segmentation.model import \ | |||||
CascadeMaskRCNNSwinModel | |||||
from modelscope.models.cv.image_instance_segmentation.postprocess_utils import \ | |||||
get_img_ins_seg_result | |||||
from modelscope.pipelines.base import Input, Pipeline | |||||
from modelscope.pipelines.builder import PIPELINES | |||||
from modelscope.preprocessors import (ImageInstanceSegmentationPreprocessor, | |||||
build_preprocessor, load_image) | |||||
from modelscope.utils.config import Config | |||||
from modelscope.utils.constant import Fields, ModelFile, Tasks | |||||
from modelscope.utils.logger import get_logger | |||||
logger = get_logger() | |||||
@PIPELINES.register_module( | |||||
Tasks.image_segmentation, | |||||
module_name=Pipelines.image_instance_segmentation) | |||||
class ImageInstanceSegmentationPipeline(Pipeline): | |||||
def __init__(self, | |||||
model: Union[CascadeMaskRCNNSwinModel, str], | |||||
preprocessor: Optional[ | |||||
ImageInstanceSegmentationPreprocessor] = None, | |||||
**kwargs): | |||||
"""use `model` and `preprocessor` to create a image instance segmentation pipeline for prediction | |||||
Args: | |||||
model (CascadeMaskRCNNSwinModel | str): a model instance | |||||
preprocessor (CascadeMaskRCNNSwinPreprocessor | None): a preprocessor instance | |||||
""" | |||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
if preprocessor is None: | |||||
config_path = os.path.join(self.model.model_dir, | |||||
ModelFile.CONFIGURATION) | |||||
cfg = Config.from_file(config_path) | |||||
self.preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv) | |||||
else: | |||||
self.preprocessor = preprocessor | |||||
self.preprocessor.eval() | |||||
self.model.eval() | |||||
def _collate_fn(self, data): | |||||
# don't require collating | |||||
return data | |||||
def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]: | |||||
filename = None | |||||
img = None | |||||
if isinstance(input, str): | |||||
filename = input | |||||
img = np.array(load_image(input)) | |||||
img = img[:, :, ::-1] # convert to bgr | |||||
elif isinstance(input, Image.Image): | |||||
img = np.array(input.convert('RGB')) | |||||
img = img[:, :, ::-1] # convert to bgr | |||||
        elif isinstance(input, np.ndarray):
            if len(input.shape) == 2:
                img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
            else:
                img = input  # assumed to already be a BGR image
        else:
            raise TypeError(f'input should be either str, PIL.Image,'
                            f' or np.ndarray, but got {type(input)}')
result = { | |||||
'img': img, | |||||
'img_shape': img.shape, | |||||
'ori_shape': img.shape, | |||||
'img_fields': ['img'], | |||||
'img_prefix': '', | |||||
'img_info': { | |||||
'filename': filename, | |||||
'ann_file': None, | |||||
'classes': None | |||||
}, | |||||
} | |||||
result = self.preprocessor(result) | |||||
# stacked as a batch | |||||
result['img'] = torch.stack([result['img']], dim=0) | |||||
result['img_metas'] = [result['img_metas'].data] | |||||
return result | |||||
def forward(self, input: Dict[str, Any], | |||||
**forward_params) -> Dict[str, Any]: | |||||
with torch.no_grad(): | |||||
output = self.model(input) | |||||
return output | |||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
result = get_img_ins_seg_result( | |||||
img_seg_result=inputs['eval_result'][0], | |||||
class_names=self.model.model.classes) | |||||
return result |
@@ -20,6 +20,7 @@ try:
    from .space.dialog_modeling_preprocessor import * # noqa F403
    from .space.dialog_state_tracking_preprocessor import * # noqa F403
    from .image import ImageColorEnhanceFinetunePreprocessor
    from .image import ImageInstanceSegmentationPreprocessor
except ModuleNotFoundError as e:
    if str(e) == "No module named 'tensorflow'":
        print(TENSORFLOW_IMPORT_ERROR.format('tts'))
@@ -136,3 +136,72 @@ class ImageColorEnhanceFinetunePreprocessor(Preprocessor):
        """
        return data
@PREPROCESSORS.register_module( | |||||
Fields.cv, | |||||
module_name=Preprocessors.image_instance_segmentation_preprocessor) | |||||
class ImageInstanceSegmentationPreprocessor(Preprocessor): | |||||
def __init__(self, *args, **kwargs): | |||||
"""image instance segmentation preprocessor in the fine-tune scenario | |||||
""" | |||||
super().__init__(*args, **kwargs) | |||||
self.training = kwargs.pop('training', True) | |||||
self.preprocessor_train_cfg = kwargs.pop('train', None) | |||||
self.preprocessor_test_cfg = kwargs.pop('val', None) | |||||
self.train_transforms = [] | |||||
self.test_transforms = [] | |||||
from modelscope.models.cv.image_instance_segmentation.datasets import \ | |||||
build_preprocess_transform | |||||
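        # build separate transform pipelines for training and evaluation from the config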
if self.preprocessor_train_cfg is not None: | |||||
if isinstance(self.preprocessor_train_cfg, dict): | |||||
self.preprocessor_train_cfg = [self.preprocessor_train_cfg] | |||||
for cfg in self.preprocessor_train_cfg: | |||||
transform = build_preprocess_transform(cfg) | |||||
self.train_transforms.append(transform) | |||||
if self.preprocessor_test_cfg is not None: | |||||
if isinstance(self.preprocessor_test_cfg, dict): | |||||
self.preprocessor_test_cfg = [self.preprocessor_test_cfg] | |||||
for cfg in self.preprocessor_test_cfg: | |||||
transform = build_preprocess_transform(cfg) | |||||
self.test_transforms.append(transform) | |||||
def train(self): | |||||
self.training = True | |||||
return | |||||
def eval(self): | |||||
self.training = False | |||||
return | |||||
@type_assert(object, object) | |||||
def __call__(self, results: Dict[str, Any]): | |||||
"""process the raw input data | |||||
Args: | |||||
results (dict): Result dict from loading pipeline. | |||||
Returns: | |||||
Dict[str, Any] | None: the preprocessed data | |||||
""" | |||||
if self.training: | |||||
transforms = self.train_transforms | |||||
else: | |||||
transforms = self.test_transforms | |||||
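        # apply the transforms sequentially; a transform may return None to drop the sample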
for t in transforms: | |||||
results = t(results) | |||||
if results is None: | |||||
return None | |||||
return results |
@@ -1,4 +1,5 @@
from .base import DummyTrainer
from .builder import build_trainer
from .cv import ImageInstanceSegmentationTrainer
from .nlp import SequenceClassificationTrainer
from .trainer import EpochBasedTrainer
@@ -0,0 +1,2 @@ | |||||
from .image_instance_segmentation_trainer import \ | |||||
ImageInstanceSegmentationTrainer |
@@ -0,0 +1,27 @@ | |||||
from modelscope.trainers.builder import TRAINERS | |||||
from modelscope.trainers.trainer import EpochBasedTrainer | |||||
@TRAINERS.register_module(module_name='image-instance-segmentation') | |||||
class ImageInstanceSegmentationTrainer(EpochBasedTrainer): | |||||
def __init__(self, *args, **kwargs): | |||||
super().__init__(*args, **kwargs) | |||||
def collate_fn(self, data): | |||||
        # skip the default collation: the batch contains special data types (e.g., BitmapMasks) that generic collate cannot handle
return data | |||||
def train(self, *args, **kwargs): | |||||
super().train(*args, **kwargs) | |||||
def evaluate(self, *args, **kwargs): | |||||
metric_values = super().evaluate(*args, **kwargs) | |||||
return metric_values | |||||
def prediction_step(self, model, inputs): | |||||
pass | |||||
def to_task_dataset(self, datasets, mode, preprocessor=None): | |||||
# wait for dataset interface to become stable... | |||||
return datasets.to_torch_dataset(preprocessor) |
@@ -0,0 +1,60 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import os | |||||
import unittest | |||||
from modelscope.hub.snapshot_download import snapshot_download | |||||
from modelscope.models import Model | |||||
from modelscope.models.cv.image_instance_segmentation.model import \ | |||||
CascadeMaskRCNNSwinModel | |||||
from modelscope.outputs import OutputKeys | |||||
from modelscope.pipelines import ImageInstanceSegmentationPipeline, pipeline | |||||
from modelscope.preprocessors import build_preprocessor | |||||
from modelscope.utils.config import Config | |||||
from modelscope.utils.constant import Fields, ModelFile, Tasks | |||||
from modelscope.utils.test_utils import test_level | |||||
class ImageInstanceSegmentationTest(unittest.TestCase): | |||||
model_id = 'damo/cv_swin-b_image-instance-segmentation_coco' | |||||
image = 'data/test/images/image_instance_segmentation.jpg' | |||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
def test_run_with_model_from_modelhub(self): | |||||
model = Model.from_pretrained(self.model_id) | |||||
config_path = os.path.join(model.model_dir, ModelFile.CONFIGURATION) | |||||
cfg = Config.from_file(config_path) | |||||
preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv) | |||||
pipeline_ins = pipeline( | |||||
task=Tasks.image_segmentation, | |||||
model=model, | |||||
preprocessor=preprocessor) | |||||
print(pipeline_ins(input=self.image)[OutputKeys.LABELS]) | |||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
def test_run_with_model_name(self): | |||||
pipeline_ins = pipeline( | |||||
task=Tasks.image_segmentation, model=self.model_id) | |||||
print(pipeline_ins(input=self.image)[OutputKeys.LABELS]) | |||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
def test_run_with_default_model(self): | |||||
pipeline_ins = pipeline(task=Tasks.image_segmentation) | |||||
print(pipeline_ins(input=self.image)[OutputKeys.LABELS]) | |||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
def test_run_by_direct_model_download(self): | |||||
cache_path = snapshot_download(self.model_id) | |||||
config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) | |||||
cfg = Config.from_file(config_path) | |||||
preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv) | |||||
model = CascadeMaskRCNNSwinModel(cache_path) | |||||
pipeline1 = ImageInstanceSegmentationPipeline( | |||||
model, preprocessor=preprocessor) | |||||
pipeline2 = pipeline( | |||||
Tasks.image_segmentation, model=model, preprocessor=preprocessor) | |||||
        print(f'pipeline1: {pipeline1(input=self.image)[OutputKeys.LABELS]}')
print(f'pipeline2: {pipeline2(input=self.image)[OutputKeys.LABELS]}') | |||||
if __name__ == '__main__': | |||||
unittest.main() |
@@ -0,0 +1,117 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import os | |||||
import shutil | |||||
import tempfile | |||||
import unittest | |||||
import zipfile | |||||
from functools import partial | |||||
from modelscope.hub.snapshot_download import snapshot_download | |||||
from modelscope.models.cv.image_instance_segmentation import \ | |||||
CascadeMaskRCNNSwinModel | |||||
from modelscope.models.cv.image_instance_segmentation.datasets import \ | |||||
ImageInstanceSegmentationCocoDataset | |||||
from modelscope.trainers import build_trainer | |||||
from modelscope.utils.config import Config | |||||
from modelscope.utils.constant import ModelFile | |||||
from modelscope.utils.test_utils import test_level | |||||
class TestImageInstanceSegmentationTrainer(unittest.TestCase): | |||||
model_id = 'damo/cv_swin-b_image-instance-segmentation_coco' | |||||
def setUp(self): | |||||
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
cache_path = snapshot_download(self.model_id) | |||||
config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) | |||||
cfg = Config.from_file(config_path) | |||||
data_root = cfg.dataset.data_root | |||||
classes = tuple(cfg.dataset.classes) | |||||
max_epochs = cfg.train.max_epochs | |||||
samples_per_gpu = cfg.train.dataloader.batch_size_per_gpu | |||||
if data_root is None: | |||||
# use default toy data | |||||
dataset_path = os.path.join(cache_path, 'toydata.zip') | |||||
with zipfile.ZipFile(dataset_path, 'r') as zipf: | |||||
zipf.extractall(cache_path) | |||||
data_root = cache_path + '/toydata/' | |||||
classes = ('Cat', 'Dog') | |||||
self.train_dataset = ImageInstanceSegmentationCocoDataset( | |||||
data_root + 'annotations/instances_train.json', | |||||
classes=classes, | |||||
data_root=data_root, | |||||
img_prefix=data_root + 'images/train/', | |||||
seg_prefix=None, | |||||
test_mode=False) | |||||
self.eval_dataset = ImageInstanceSegmentationCocoDataset( | |||||
data_root + 'annotations/instances_val.json', | |||||
classes=classes, | |||||
data_root=data_root, | |||||
img_prefix=data_root + 'images/val/', | |||||
seg_prefix=None, | |||||
test_mode=True) | |||||
from mmcv.parallel import collate | |||||
self.collate_fn = partial(collate, samples_per_gpu=samples_per_gpu) | |||||
self.max_epochs = max_epochs | |||||
        self.tmp_dir = tempfile.mkdtemp()
def tearDown(self): | |||||
shutil.rmtree(self.tmp_dir) | |||||
super().tearDown() | |||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
def test_trainer(self): | |||||
kwargs = dict( | |||||
model=self.model_id, | |||||
data_collator=self.collate_fn, | |||||
train_dataset=self.train_dataset, | |||||
eval_dataset=self.eval_dataset, | |||||
work_dir=self.tmp_dir) | |||||
trainer = build_trainer( | |||||
name='image-instance-segmentation', default_args=kwargs) | |||||
trainer.train() | |||||
results_files = os.listdir(self.tmp_dir) | |||||
self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||||
for i in range(self.max_epochs): | |||||
self.assertIn(f'epoch_{i+1}.pth', results_files) | |||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
def test_trainer_with_model_and_args(self): | |||||
cache_path = snapshot_download(self.model_id) | |||||
model = CascadeMaskRCNNSwinModel.from_pretrained(cache_path) | |||||
kwargs = dict( | |||||
cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), | |||||
model=model, | |||||
data_collator=self.collate_fn, | |||||
train_dataset=self.train_dataset, | |||||
eval_dataset=self.eval_dataset, | |||||
work_dir=self.tmp_dir) | |||||
trainer = build_trainer( | |||||
name='image-instance-segmentation', default_args=kwargs) | |||||
trainer.train() | |||||
results_files = os.listdir(self.tmp_dir) | |||||
self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||||
for i in range(self.max_epochs): | |||||
self.assertIn(f'epoch_{i+1}.pth', results_files) | |||||
if __name__ == '__main__': | |||||
unittest.main() |