Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8967432
* create ocr_detection task
* replace C++ NMS with Python version
* replace C++ decoder with Python version
* add requirements for ocr_detection
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:5c8435db5583400be5d11a2c17910c96133b462c8a99ccaf0e19f4aac34e0a94 | |||
size 141149 |
@@ -28,6 +28,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
Tasks.image_generation: | |||
('person-image-cartoon', | |||
'damo/cv_unet_person-image-cartoon_compound-models'), | |||
Tasks.ocr_detection: ('ocr-detection', | |||
'damo/cv_resnet18_ocr-detection-line-level_damo'), | |||
} | |||
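A minimal usage sketch (not part of the patch): with the mapping above registered, constructing the pipeline from the task name alone resolves to the damo line-level detection model.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# resolves to 'damo/cv_resnet18_ocr-detection-line-level_damo' via DEFAULT_MODEL_FOR_PIPELINE
ocr_detection = pipeline(Tasks.ocr_detection)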
@@ -1,2 +1,3 @@ | |||
from .image_cartoon_pipeline import ImageCartoonPipeline | |||
from .image_matting_pipeline import ImageMattingPipeline | |||
from .ocr_detection_pipeline import OCRDetectionPipeline |
@@ -0,0 +1,167 @@ | |||
import math | |||
import os | |||
import os.path as osp | |||
import sys | |||
from typing import Any, Dict, List, Tuple, Union | |||
import cv2 | |||
import numpy as np | |||
import PIL | |||
import tensorflow as tf | |||
import tf_slim as slim | |||
from modelscope.pipelines.base import Input | |||
from modelscope.preprocessors import load_image | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
from ..base import Pipeline | |||
from ..builder import PIPELINES | |||
from .ocr_utils import model_resnet_mutex_v4_linewithchar, ops, utils | |||
if tf.__version__ >= '2.0': | |||
tf = tf.compat.v1 | |||
tf.compat.v1.disable_eager_execution() | |||
logger = get_logger() | |||
# constant | |||
RBOX_DIM = 5 | |||
OFFSET_DIM = 6 | |||
WORD_POLYGON_DIM = 8 | |||
OFFSET_VARIANCE = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1] | |||
FLAGS = tf.app.flags.FLAGS | |||
tf.app.flags.DEFINE_float('node_threshold', 0.4, | |||
'Confidence threshold for nodes') | |||
tf.app.flags.DEFINE_float('link_threshold', 0.6, | |||
'Confidence threshold for links') | |||
@PIPELINES.register_module( | |||
Tasks.ocr_detection, module_name=Tasks.ocr_detection) | |||
class OCRDetectionPipeline(Pipeline): | |||
def __init__(self, model: str): | |||
super().__init__(model=model) | |||
model_path = osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER,
'checkpoint-80000')
config = tf.ConfigProto(allow_soft_placement=True) | |||
config.gpu_options.allow_growth = True | |||
self._session = tf.Session(config=config) | |||
global_step = tf.get_variable( | |||
'global_step', [], | |||
initializer=tf.constant_initializer(0), | |||
dtype=tf.int64, | |||
trainable=False) | |||
variable_averages = tf.train.ExponentialMovingAverage( | |||
0.997, global_step) | |||
self.input_images = tf.placeholder( | |||
tf.float32, shape=[1, 1024, 1024, 3], name='input_images') | |||
self.output = {} | |||
# detector | |||
detector = model_resnet_mutex_v4_linewithchar.SegLinkDetector() | |||
all_maps = detector.build_model(self.input_images, is_training=False) | |||
# decode local predictions | |||
all_nodes, all_links, all_reg = [], [], [] | |||
for i, maps in enumerate(all_maps): | |||
cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2] | |||
reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE) | |||
cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2])) | |||
lnk_prob_pos = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 4])[:, :2]) | |||
lnk_prob_mut = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 4])[:, 2:]) | |||
lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1) | |||
all_nodes.append(cls_prob) | |||
all_links.append(lnk_prob) | |||
all_reg.append(reg_maps) | |||
# decode segments and links | |||
image_size = tf.shape(self.input_images)[1:3] | |||
segments, group_indices, segment_counts, _ = ops.decode_segments_links_python( | |||
image_size, | |||
all_nodes, | |||
all_links, | |||
all_reg, | |||
anchor_sizes=list(detector.anchor_sizes)) | |||
# combine segments | |||
combined_rboxes, combined_counts = ops.combine_segments_python( | |||
segments, group_indices, segment_counts) | |||
self.output['combined_rboxes'] = combined_rboxes | |||
self.output['combined_counts'] = combined_counts | |||
with self._session.as_default() as sess: | |||
logger.info(f'loading model from {model_path}') | |||
# load model | |||
model_loader = tf.train.Saver( | |||
variable_averages.variables_to_restore()) | |||
model_loader.restore(sess, model_path) | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
if isinstance(input, str): | |||
img = np.array(load_image(input)) | |||
elif isinstance(input, PIL.Image.Image): | |||
img = np.array(input.convert('RGB')) | |||
elif isinstance(input, np.ndarray):
if len(input.shape) == 2:
input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
img = input[:, :, ::-1] # assume opencv BGR input, convert to rgb order
else:
raise TypeError(f'input should be either str, PIL.Image or'
f' np.ndarray, but got {type(input)}')
h, w, c = img.shape | |||
img_pad = np.zeros((max(h, w), max(h, w), 3), dtype=np.float32) | |||
img_pad[:h, :w, :] = img | |||
resize_size = 1024 | |||
img_pad_resize = cv2.resize(img_pad, (resize_size, resize_size)) | |||
img_pad_resize = cv2.cvtColor(img_pad_resize, cv2.COLOR_RGB2BGR) | |||
img_pad_resize = img_pad_resize - np.array([123.68, 116.78, 103.94], | |||
dtype=np.float32) | |||
resize_size = tf.stack([resize_size, resize_size]) | |||
orig_size = tf.stack([max(h, w), max(h, w)]) | |||
self.output['orig_size'] = orig_size | |||
self.output['resize_size'] = resize_size | |||
result = {'img': np.expand_dims(img_pad_resize, axis=0)} | |||
return result | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
with self._session.as_default(): | |||
feed_dict = {self.input_images: input['img']} | |||
sess_outputs = self._session.run(self.output, feed_dict=feed_dict) | |||
return sess_outputs | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
rboxes = inputs['combined_rboxes'][0] | |||
count = inputs['combined_counts'][0] | |||
rboxes = rboxes[:count, :] | |||
# convert rboxes to polygons and map their coordinates back to the original image
orig_h, orig_w = inputs['orig_size'] | |||
resize_h, resize_w = inputs['resize_size'] | |||
polygons = utils.rboxes_to_polygons(rboxes) | |||
scale_y = float(orig_h) / float(resize_h) | |||
scale_x = float(orig_w) / float(resize_w) | |||
# confine polygons inside image | |||
polygons[:, ::2] = np.maximum( | |||
0, np.minimum(polygons[:, ::2] * scale_x, orig_w - 1)) | |||
polygons[:, 1::2] = np.maximum( | |||
0, np.minimum(polygons[:, 1::2] * scale_y, orig_h - 1)) | |||
polygons = np.round(polygons).astype(np.int32) | |||
# nms | |||
dt_n9 = [o + [utils.cal_width(o)] for o in polygons.tolist()] | |||
dt_nms = utils.nms_python(dt_n9) | |||
dt_polygons = np.array([o[:8] for o in dt_nms]) | |||
result = {'det_polygons': dt_polygons} | |||
return result |
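A small worked example (not part of the patch) of the coordinate mapping that preprocess() and postprocess() implement together, using assumed input dimensions of 800x600:

# an 800x600 (w x h) image is zero-padded to an 800x800 square and resized to
# 1024x1024; postprocess() then scales detections by orig_size / resize_size
# and clips them to the padded square.
orig_h, orig_w = 600, 800
pad = max(orig_h, orig_w)            # 800
resize_size = 1024
scale = float(pad) / resize_size     # scale_x == scale_y == 0.78125
x_resized = 512.0                    # an x-coordinate predicted on the 1024x1024 input
x_orig = max(0, min(x_resized * scale, pad - 1))
print(x_orig)                        # 400.0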
@@ -0,0 +1,158 @@ | |||
import tensorflow as tf | |||
import tf_slim as slim | |||
from . import ops, resnet18_v1, resnet_utils | |||
if tf.__version__ >= '2.0': | |||
tf = tf.compat.v1 | |||
# constants | |||
OFFSET_DIM = 6 | |||
N_LOCAL_LINKS = 8 | |||
N_CROSS_LINKS = 4 | |||
N_SEG_CLASSES = 2 | |||
N_LNK_CLASSES = 4 | |||
POS_LABEL = 1 | |||
NEG_LABEL = 0 | |||
class SegLinkDetector(): | |||
def __init__(self): | |||
self.anchor_sizes = [6., 11.84210526, 23.68421053, 45., 90., 150.] | |||
def _detection_classifier(self, | |||
maps, | |||
ksize, | |||
weight_decay, | |||
cross_links=False, | |||
scope=None): | |||
with tf.variable_scope(scope): | |||
seg_depth = N_SEG_CLASSES | |||
if cross_links: | |||
lnk_depth = N_LNK_CLASSES * (N_LOCAL_LINKS + N_CROSS_LINKS) | |||
else: | |||
lnk_depth = N_LNK_CLASSES * N_LOCAL_LINKS | |||
reg_depth = OFFSET_DIM | |||
map_depth = maps.get_shape()[3] | |||
inter_maps, inter_relu = ops.conv2d( | |||
maps, map_depth, 256, 1, 1, 'SAME', scope='conv_inter') | |||
dir_maps, dir_relu = ops.conv2d( | |||
inter_relu, 256, 2, ksize, 1, 'SAME', scope='conv_dir') | |||
cen_maps, cen_relu = ops.conv2d( | |||
inter_relu, 256, 2, ksize, 1, 'SAME', scope='conv_cen') | |||
pol_maps, pol_relu = ops.conv2d( | |||
inter_relu, 256, 8, ksize, 1, 'SAME', scope='conv_pol') | |||
concat_relu = tf.concat([dir_relu, cen_relu, pol_relu], axis=-1) | |||
_, lnk_embedding = ops.conv_relu( | |||
concat_relu, 12, 256, 1, 1, scope='lnk_embedding') | |||
lnk_maps, lnk_relu = ops.conv2d( | |||
inter_relu + lnk_embedding, | |||
256, | |||
lnk_depth, | |||
ksize, | |||
1, | |||
'SAME', | |||
scope='conv_lnk') | |||
char_seg_maps, char_seg_relu = ops.conv2d( | |||
inter_relu, | |||
256, | |||
seg_depth, | |||
ksize, | |||
1, | |||
'SAME', | |||
scope='conv_char_cls') | |||
char_reg_maps, char_reg_relu = ops.conv2d( | |||
inter_relu, | |||
256, | |||
reg_depth, | |||
ksize, | |||
1, | |||
'SAME', | |||
scope='conv_char_reg') | |||
concat_char_relu = tf.concat([char_seg_relu, char_reg_relu], | |||
axis=-1) | |||
_, char_embedding = ops.conv_relu( | |||
concat_char_relu, 8, 256, 1, 1, scope='conv_char_embedding') | |||
seg_maps, seg_relu = ops.conv2d( | |||
inter_relu + char_embedding, | |||
256, | |||
seg_depth, | |||
ksize, | |||
1, | |||
'SAME', | |||
scope='conv_cls') | |||
reg_maps, reg_relu = ops.conv2d( | |||
inter_relu + char_embedding, | |||
256, | |||
reg_depth, | |||
ksize, | |||
1, | |||
'SAME', | |||
scope='conv_reg') | |||
return seg_relu, lnk_relu, reg_relu | |||
def _build_cnn(self, images, weight_decay, is_training): | |||
with slim.arg_scope( | |||
resnet18_v1.resnet_arg_scope(weight_decay=weight_decay)): | |||
logits, end_points = resnet18_v1.resnet_v1_18( | |||
images, is_training=is_training, scope='resnet_v1_18') | |||
outputs = { | |||
'conv3_3': end_points['pool1'], | |||
'conv4_3': end_points['pool2'], | |||
'fc7': end_points['pool3'], | |||
'conv8_2': end_points['pool4'], | |||
'conv9_2': end_points['pool5'], | |||
'conv10_2': end_points['pool6'], | |||
} | |||
return outputs | |||
def build_model(self, images, is_training=True, scope=None): | |||
weight_decay = 5e-4 # FLAGS.weight_decay | |||
cnn_outputs = self._build_cnn(images, weight_decay, is_training) | |||
det_0 = self._detection_classifier( | |||
cnn_outputs['conv3_3'], | |||
3, | |||
weight_decay, | |||
cross_links=False, | |||
scope='dete_0') | |||
det_1 = self._detection_classifier( | |||
cnn_outputs['conv4_3'], | |||
3, | |||
weight_decay, | |||
cross_links=True, | |||
scope='dete_1') | |||
det_2 = self._detection_classifier( | |||
cnn_outputs['fc7'], | |||
3, | |||
weight_decay, | |||
cross_links=True, | |||
scope='dete_2') | |||
det_3 = self._detection_classifier( | |||
cnn_outputs['conv8_2'], | |||
3, | |||
weight_decay, | |||
cross_links=True, | |||
scope='dete_3') | |||
det_4 = self._detection_classifier( | |||
cnn_outputs['conv9_2'], | |||
3, | |||
weight_decay, | |||
cross_links=True, | |||
scope='dete_4') | |||
det_5 = self._detection_classifier( | |||
cnn_outputs['conv10_2'], | |||
3, | |||
weight_decay, | |||
cross_links=True, | |||
scope='dete_5') | |||
outputs = [det_0, det_1, det_2, det_3, det_4, det_5] | |||
return outputs |
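A standalone sketch (not part of the patch) of assembling the detector graph the way OCRDetectionPipeline.__init__ does; the import path is an assumption based on the pipeline's relative import of ocr_utils.

import tensorflow as tf

from modelscope.pipelines.cv.ocr_utils.model_resnet_mutex_v4_linewithchar import \
    SegLinkDetector  # assumed module location

if tf.__version__ >= '2.0':
    tf = tf.compat.v1
    tf.compat.v1.disable_eager_execution()

images = tf.placeholder(tf.float32, shape=[1, 1024, 1024, 3], name='input_images')
detector = SegLinkDetector()
# one (cls_maps, lnk_maps, reg_maps) tuple per scale, i.e. per entry in
# detector.anchor_sizes (heads dete_0 ... dete_5)
all_maps = detector.build_model(images, is_training=False)
assert len(all_maps) == len(detector.anchor_sizes)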
@@ -0,0 +1,432 @@ | |||
"""Contains definitions for the original form of Residual Networks. | |||
The 'v1' residual networks (ResNets) implemented in this module were proposed | |||
by: | |||
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun | |||
Deep Residual Learning for Image Recognition. arXiv:1512.03385 | |||
Other variants were introduced in: | |||
[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun | |||
Identity Mappings in Deep Residual Networks. arXiv: 1603.05027 | |||
The networks defined in this module utilize the basic and bottleneck building blocks of
[1] with projection shortcuts only for increasing depths. They employ batch | |||
normalization *after* every weight layer. This is the architecture used by | |||
MSRA in the Imagenet and MSCOCO 2016 competition models ResNet-101 and | |||
ResNet-152. See [2; Fig. 1a] for a comparison between the current 'v1' | |||
architecture and the alternative 'v2' architecture of [2] which uses batch | |||
normalization *before* every weight layer in the so-called full pre-activation | |||
units. | |||
Typical use: | |||
from tensorflow.contrib.slim.nets import resnet_v1 | |||
ResNet-101 for image classification into 1000 classes: | |||
# inputs has shape [batch, 224, 224, 3] | |||
with slim.arg_scope(resnet_v1.resnet_arg_scope()): | |||
net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False) | |||
ResNet-101 for semantic segmentation into 21 classes: | |||
# inputs has shape [batch, 513, 513, 3] | |||
with slim.arg_scope(resnet_v1.resnet_arg_scope()): | |||
net, end_points = resnet_v1.resnet_v1_101(inputs, | |||
21, | |||
is_training=False, | |||
global_pool=False, | |||
output_stride=16) | |||
""" | |||
import tensorflow as tf | |||
import tf_slim as slim | |||
from . import resnet_utils | |||
if tf.__version__ >= '2.0': | |||
tf = tf.compat.v1 | |||
resnet_arg_scope = resnet_utils.resnet_arg_scope | |||
@slim.add_arg_scope | |||
def basicblock(inputs, | |||
depth, | |||
depth_bottleneck, | |||
stride, | |||
rate=1, | |||
outputs_collections=None, | |||
scope=None): | |||
"""Bottleneck residual unit variant with BN after convolutions. | |||
This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for | |||
its definition. Note that we use here the bottleneck variant which has an | |||
extra bottleneck layer. | |||
When putting together two consecutive ResNet blocks that use this unit, one | |||
should use stride = 2 in the last unit of the first block. | |||
Args: | |||
inputs: A tensor of size [batch, height, width, channels]. | |||
depth: The depth of the ResNet unit output. | |||
depth_bottleneck: Unused by this unit; kept for interface parity with `bottleneck`.
stride: The ResNet unit's stride. Determines the amount of downsampling of | |||
the units output compared to its input. | |||
rate: An integer, rate for atrous convolution. | |||
outputs_collections: Collection to add the ResNet unit output. | |||
scope: Optional variable_scope. | |||
Returns: | |||
The ResNet unit's output. | |||
""" | |||
with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc: | |||
depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) | |||
if depth == depth_in: | |||
shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') | |||
else: | |||
shortcut = slim.conv2d( | |||
inputs, | |||
depth, [1, 1], | |||
stride=stride, | |||
activation_fn=None, | |||
scope='shortcut') | |||
residual = resnet_utils.conv2d_same( | |||
inputs, depth, 3, stride, rate=rate, scope='conv1') | |||
residual = resnet_utils.conv2d_same( | |||
residual, depth, 3, 1, rate=rate, scope='conv2') | |||
output = tf.nn.relu(residual + shortcut) | |||
return slim.utils.collect_named_outputs(outputs_collections, | |||
sc.original_name_scope, output) | |||
@slim.add_arg_scope | |||
def bottleneck(inputs, | |||
depth, | |||
depth_bottleneck, | |||
stride, | |||
rate=1, | |||
outputs_collections=None, | |||
scope=None): | |||
"""Bottleneck residual unit variant with BN after convolutions. | |||
This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for | |||
its definition. Note that we use here the bottleneck variant which has an | |||
extra bottleneck layer. | |||
When putting together two consecutive ResNet blocks that use this unit, one | |||
should use stride = 2 in the last unit of the first block. | |||
Args: | |||
inputs: A tensor of size [batch, height, width, channels]. | |||
depth: The depth of the ResNet unit output. | |||
depth_bottleneck: The depth of the bottleneck layers. | |||
stride: The ResNet unit's stride. Determines the amount of downsampling of | |||
the units output compared to its input. | |||
rate: An integer, rate for atrous convolution. | |||
outputs_collections: Collection to add the ResNet unit output. | |||
scope: Optional variable_scope. | |||
Returns: | |||
The ResNet unit's output. | |||
""" | |||
with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc: | |||
depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) | |||
if depth == depth_in: | |||
shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') | |||
else: | |||
shortcut = slim.conv2d( | |||
inputs, | |||
depth, [1, 1], | |||
stride=stride, | |||
activation_fn=None, | |||
scope='shortcut') | |||
residual = slim.conv2d( | |||
inputs, depth_bottleneck, [1, 1], stride=1, scope='conv1') | |||
residual = resnet_utils.conv2d_same( | |||
residual, depth_bottleneck, 3, stride, rate=rate, scope='conv2') | |||
residual = slim.conv2d( | |||
residual, | |||
depth, [1, 1], | |||
stride=1, | |||
activation_fn=None, | |||
scope='conv3') | |||
output = tf.nn.relu(shortcut + residual) | |||
return slim.utils.collect_named_outputs(outputs_collections, | |||
sc.original_name_scope, output) | |||
def resnet_v1(inputs, | |||
blocks, | |||
num_classes=None, | |||
is_training=True, | |||
global_pool=True, | |||
output_stride=None, | |||
include_root_block=True, | |||
spatial_squeeze=True, | |||
reuse=None, | |||
scope=None): | |||
"""Generator for v1 ResNet models. | |||
This function generates a family of ResNet v1 models. See the resnet_v1_*() | |||
methods for specific model instantiations, obtained by selecting different | |||
block instantiations that produce ResNets of various depths. | |||
Training for image classification on Imagenet is usually done with [224, 224] | |||
inputs, resulting in [7, 7] feature maps at the output of the last ResNet | |||
block for the ResNets defined in [1] that have nominal stride equal to 32. | |||
However, for dense prediction tasks we advise that one uses inputs with | |||
spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In | |||
this case the feature maps at the ResNet output will have spatial shape | |||
[(height - 1) / output_stride + 1, (width - 1) / output_stride + 1] | |||
and corners exactly aligned with the input image corners, which greatly | |||
facilitates alignment of the features to the image. Using as input [225, 225] | |||
images results in [8, 8] feature maps at the output of the last ResNet block. | |||
For dense prediction tasks, the ResNet needs to run in fully-convolutional | |||
(FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all | |||
have nominal stride equal to 32 and a good choice in FCN mode is to use | |||
output_stride=16 in order to increase the density of the computed features at | |||
small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915. | |||
Args: | |||
inputs: A tensor of size [batch, height_in, width_in, channels]. | |||
blocks: A list of length equal to the number of ResNet blocks. Each element | |||
is a resnet_utils.Block object describing the units in the block. | |||
num_classes: Number of predicted classes for classification tasks. If None | |||
we return the features before the logit layer. | |||
is_training: whether is training or not. | |||
global_pool: If True, we perform global average pooling before computing the | |||
logits. Set to True for image classification, False for dense prediction. | |||
output_stride: If None, then the output will be computed at the nominal | |||
network stride. If output_stride is not None, it specifies the requested | |||
ratio of input to output spatial resolution. | |||
include_root_block: If True, include the initial convolution followed by | |||
max-pooling, if False excludes it. | |||
spatial_squeeze: if True, logits is of shape [B, C], if false logits is | |||
of shape [B, 1, 1, C], where B is batch_size and C is number of classes. | |||
reuse: whether or not the network and its variables should be reused. To be | |||
able to reuse 'scope' must be given. | |||
scope: Optional variable_scope. | |||
Returns: | |||
net: A rank-4 tensor of size [batch, height_out, width_out, channels_out]. | |||
If global_pool is False, then height_out and width_out are reduced by a | |||
factor of output_stride compared to the respective height_in and width_in, | |||
else both height_out and width_out equal one. If num_classes is None, then | |||
net is the output of the last ResNet block, potentially after global | |||
average pooling. If num_classes is not None, net contains the pre-softmax | |||
activations. | |||
end_points: A dictionary from components of the network to the corresponding | |||
activation. | |||
Raises: | |||
ValueError: If the target output_stride is not valid. | |||
""" | |||
with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc: | |||
end_points_collection = sc.name + '_end_points' | |||
with slim.arg_scope( | |||
[slim.conv2d, bottleneck, resnet_utils.stack_blocks_dense], | |||
outputs_collections=end_points_collection): | |||
with slim.arg_scope([slim.batch_norm], is_training=is_training): | |||
net = inputs | |||
if include_root_block: | |||
if output_stride is not None: | |||
if output_stride % 4 != 0: | |||
raise ValueError( | |||
'The output_stride needs to be a multiple of 4.' | |||
) | |||
output_stride /= 4 | |||
net = resnet_utils.conv2d_same( | |||
net, 64, 7, stride=2, scope='conv1') | |||
net = tf.pad(net, [[0, 0], [1, 1], [1, 1], [0, 0]]) | |||
net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1') | |||
net = slim.utils.collect_named_outputs( | |||
end_points_collection, 'pool2', net) | |||
net = resnet_utils.stack_blocks_dense(net, blocks, | |||
output_stride) | |||
end_points = slim.utils.convert_collection_to_dict( | |||
end_points_collection) | |||
end_points['pool1'] = end_points['resnet_v1_18/block2/unit_2'] | |||
end_points['pool2'] = end_points['resnet_v1_18/block3/unit_2'] | |||
end_points['pool3'] = end_points['resnet_v1_18/block4/unit_2'] | |||
end_points['pool4'] = end_points['resnet_v1_18/block5/unit_2'] | |||
end_points['pool5'] = end_points['resnet_v1_18/block6/unit_2'] | |||
end_points['pool6'] = net | |||
return net, end_points | |||
resnet_v1.default_image_size = 224 | |||
def resnet_v1_18(inputs, | |||
num_classes=None, | |||
is_training=True, | |||
global_pool=True, | |||
output_stride=None, | |||
spatial_squeeze=True, | |||
reuse=None, | |||
scope='resnet_v1_18'): | |||
"""ResNet-18 model of [1]. See resnet_v1() for arg and return description.""" | |||
blocks = [ | |||
resnet_utils.Block('block1', basicblock, | |||
[(64, 64, 1)] + [(64, 64, 1)]), | |||
resnet_utils.Block('block2', basicblock, | |||
[(128, 128, 1)] + [(128, 128, 1)]), | |||
resnet_utils.Block('block3', basicblock, | |||
[(256, 256, 2)] + [(256, 256, 1)]), | |||
resnet_utils.Block('block4', basicblock, | |||
[(512, 512, 2)] + [(512, 512, 1)]), | |||
resnet_utils.Block('block5', basicblock, | |||
[(256, 256, 2)] + [(256, 256, 1)]), | |||
resnet_utils.Block('block6', basicblock, | |||
[(256, 256, 2)] + [(256, 256, 1)]), | |||
resnet_utils.Block('block7', basicblock, | |||
[(256, 256, 2)] + [(256, 256, 1)]), | |||
] | |||
return resnet_v1( | |||
inputs, | |||
blocks, | |||
num_classes, | |||
is_training, | |||
global_pool=global_pool, | |||
output_stride=output_stride, | |||
include_root_block=True, | |||
spatial_squeeze=spatial_squeeze, | |||
reuse=reuse, | |||
scope=scope) | |||
resnet_v1_18.default_image_size = resnet_v1.default_image_size | |||
def resnet_v1_50(inputs, | |||
num_classes=None, | |||
is_training=True, | |||
global_pool=True, | |||
output_stride=None, | |||
spatial_squeeze=True, | |||
reuse=None, | |||
scope='resnet_v1_50'): | |||
"""ResNet-50 model of [1]. See resnet_v1() for arg and return description.""" | |||
blocks = [ | |||
resnet_utils.Block('block1', bottleneck, | |||
[(256, 64, 1)] * 2 + [(256, 64, 2)]), | |||
resnet_utils.Block('block2', bottleneck, | |||
[(512, 128, 1)] * 3 + [(512, 128, 2)]), | |||
resnet_utils.Block('block3', bottleneck, | |||
[(1024, 256, 1)] * 5 + [(1024, 256, 2)]), | |||
resnet_utils.Block('block4', bottleneck, | |||
[(2048, 512, 1)] * 3 + [(2048, 512, 2)]), | |||
resnet_utils.Block('block5', bottleneck, | |||
[(1024, 256, 1)] * 2 + [(1024, 256, 2)]), | |||
resnet_utils.Block('block6', bottleneck, [(1024, 256, 1)] * 2), | |||
] | |||
return resnet_v1( | |||
inputs, | |||
blocks, | |||
num_classes, | |||
is_training, | |||
global_pool=global_pool, | |||
output_stride=output_stride, | |||
include_root_block=True, | |||
spatial_squeeze=spatial_squeeze, | |||
reuse=reuse, | |||
scope=scope) | |||
resnet_v1_50.default_image_size = resnet_v1.default_image_size | |||
def resnet_v1_101(inputs, | |||
num_classes=None, | |||
is_training=True, | |||
global_pool=True, | |||
output_stride=None, | |||
spatial_squeeze=True, | |||
reuse=None, | |||
scope='resnet_v1_101'): | |||
"""ResNet-101 model of [1]. See resnet_v1() for arg and return description.""" | |||
blocks = [ | |||
resnet_utils.Block('block1', bottleneck, | |||
[(256, 64, 1)] * 2 + [(256, 64, 2)]), | |||
resnet_utils.Block('block2', bottleneck, | |||
[(512, 128, 1)] * 3 + [(512, 128, 2)]), | |||
resnet_utils.Block('block3', bottleneck, | |||
[(1024, 256, 1)] * 22 + [(1024, 256, 2)]), | |||
resnet_utils.Block('block4', bottleneck, [(2048, 512, 1)] * 3) | |||
] | |||
return resnet_v1( | |||
inputs, | |||
blocks, | |||
num_classes, | |||
is_training, | |||
global_pool=global_pool, | |||
output_stride=output_stride, | |||
include_root_block=True, | |||
spatial_squeeze=spatial_squeeze, | |||
reuse=reuse, | |||
scope=scope) | |||
resnet_v1_101.default_image_size = resnet_v1.default_image_size | |||
def resnet_v1_152(inputs, | |||
num_classes=None, | |||
is_training=True, | |||
global_pool=True, | |||
output_stride=None, | |||
spatial_squeeze=True, | |||
reuse=None, | |||
scope='resnet_v1_152'): | |||
"""ResNet-152 model of [1]. See resnet_v1() for arg and return description.""" | |||
blocks = [ | |||
resnet_utils.Block('block1', bottleneck, | |||
[(256, 64, 1)] * 2 + [(256, 64, 2)]), | |||
resnet_utils.Block('block2', bottleneck, | |||
[(512, 128, 1)] * 7 + [(512, 128, 2)]), | |||
resnet_utils.Block('block3', bottleneck, | |||
[(1024, 256, 1)] * 35 + [(1024, 256, 2)]), | |||
resnet_utils.Block('block4', bottleneck, [(2048, 512, 1)] * 3) | |||
] | |||
return resnet_v1( | |||
inputs, | |||
blocks, | |||
num_classes, | |||
is_training, | |||
global_pool=global_pool, | |||
output_stride=output_stride, | |||
include_root_block=True, | |||
spatial_squeeze=spatial_squeeze, | |||
reuse=reuse, | |||
scope=scope) | |||
resnet_v1_152.default_image_size = resnet_v1.default_image_size | |||
def resnet_v1_200(inputs, | |||
num_classes=None, | |||
is_training=True, | |||
global_pool=True, | |||
output_stride=None, | |||
spatial_squeeze=True, | |||
reuse=None, | |||
scope='resnet_v1_200'): | |||
"""ResNet-200 model of [2]. See resnet_v1() for arg and return description.""" | |||
blocks = [ | |||
resnet_utils.Block('block1', bottleneck, | |||
[(256, 64, 1)] * 2 + [(256, 64, 2)]), | |||
resnet_utils.Block('block2', bottleneck, | |||
[(512, 128, 1)] * 23 + [(512, 128, 2)]), | |||
resnet_utils.Block('block3', bottleneck, | |||
[(1024, 256, 1)] * 35 + [(1024, 256, 2)]), | |||
resnet_utils.Block('block4', bottleneck, [(2048, 512, 1)] * 3) | |||
] | |||
return resnet_v1( | |||
inputs, | |||
blocks, | |||
num_classes, | |||
is_training, | |||
global_pool=global_pool, | |||
output_stride=output_stride, | |||
include_root_block=True, | |||
spatial_squeeze=spatial_squeeze, | |||
reuse=reuse, | |||
scope=scope) | |||
resnet_v1_200.default_image_size = resnet_v1.default_image_size | |||
if __name__ == '__main__':
images = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='input')
with slim.arg_scope(resnet_arg_scope()):
net, end_points = resnet_v1_50(images)
@@ -0,0 +1,231 @@ | |||
"""Contains building blocks for various versions of Residual Networks. | |||
Residual networks (ResNets) were proposed in: | |||
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun | |||
Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015 | |||
More variants were introduced in: | |||
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun | |||
Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016 | |||
We can obtain different ResNet variants by changing the network depth, width, | |||
and form of residual unit. This module implements the infrastructure for | |||
building them. Concrete ResNet units and full ResNet networks are implemented in | |||
the accompanying resnet_v1.py and resnet_v2.py modules. | |||
Compared to https://github.com/KaimingHe/deep-residual-networks, in the current | |||
implementation we subsample the output activations in the last residual unit of | |||
each block, instead of subsampling the input activations in the first residual | |||
unit of each block. The two implementations give identical results but our | |||
implementation is more memory efficient. | |||
""" | |||
import collections | |||
import tensorflow as tf | |||
import tf_slim as slim | |||
if tf.__version__ >= '2.0': | |||
tf = tf.compat.v1 | |||
class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])): | |||
"""A named tuple describing a ResNet block. | |||
Its parts are: | |||
scope: The scope of the `Block`. | |||
unit_fn: The ResNet unit function which takes as input a `Tensor` and | |||
returns another `Tensor` with the output of the ResNet unit. | |||
args: A list of length equal to the number of units in the `Block`. The list | |||
contains one (depth, depth_bottleneck, stride) tuple for each unit in the | |||
block to serve as argument to unit_fn. | |||
""" | |||
def subsample(inputs, factor, scope=None): | |||
"""Subsamples the input along the spatial dimensions. | |||
Args: | |||
inputs: A `Tensor` of size [batch, height_in, width_in, channels]. | |||
factor: The subsampling factor. | |||
scope: Optional variable_scope. | |||
Returns: | |||
output: A `Tensor` of size [batch, height_out, width_out, channels] with the | |||
input, either intact (if factor == 1) or subsampled (if factor > 1). | |||
""" | |||
if factor == 1: | |||
return inputs | |||
else: | |||
return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope) | |||
def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): | |||
"""Strided 2-D convolution with 'SAME' padding. | |||
When stride > 1, then we do explicit zero-padding, followed by conv2d with | |||
'VALID' padding. | |||
Note that | |||
net = conv2d_same(inputs, num_outputs, 3, stride=stride) | |||
is equivalent to | |||
net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME') | |||
net = subsample(net, factor=stride) | |||
whereas | |||
net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME') | |||
is different when the input's height or width is even, which is why we add the | |||
current function. For more details, see ResnetUtilsTest.testConv2DSameEven(). | |||
Args: | |||
inputs: A 4-D tensor of size [batch, height_in, width_in, channels]. | |||
num_outputs: An integer, the number of output filters. | |||
kernel_size: An int with the kernel_size of the filters. | |||
stride: An integer, the output stride. | |||
rate: An integer, rate for atrous convolution. | |||
scope: Scope. | |||
Returns: | |||
output: A 4-D tensor of size [batch, height_out, width_out, channels] with | |||
the convolution output. | |||
""" | |||
if stride == 1: | |||
return slim.conv2d( | |||
inputs, | |||
num_outputs, | |||
kernel_size, | |||
stride=1, | |||
rate=rate, | |||
padding='SAME', | |||
scope=scope) | |||
else: | |||
kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) | |||
pad_total = kernel_size_effective - 1 | |||
pad_beg = pad_total // 2 | |||
pad_end = pad_total - pad_beg | |||
inputs = tf.pad( | |||
inputs, [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) | |||
return slim.conv2d( | |||
inputs, | |||
num_outputs, | |||
kernel_size, | |||
stride=stride, | |||
rate=rate, | |||
padding='VALID', | |||
scope=scope) | |||
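# Illustrative check (not part of the patch) of the explicit-padding arithmetic
# used above: for a 3x3 kernel with atrous rate 1, pad_total is 2, so conv2d_same
# pads one pixel on each side before the stride>1 'VALID' convolution.
_kernel_size, _rate = 3, 1
_kernel_size_effective = _kernel_size + (_kernel_size - 1) * (_rate - 1)  # 3
_pad_total = _kernel_size_effective - 1                                   # 2
_pad_beg, _pad_end = _pad_total // 2, _pad_total - _pad_total // 2        # 1, 1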
@slim.add_arg_scope | |||
def stack_blocks_dense(net, | |||
blocks, | |||
output_stride=None, | |||
outputs_collections=None): | |||
"""Stacks ResNet `Blocks` and controls output feature density. | |||
First, this function creates scopes for the ResNet in the form of | |||
'block_name/unit_1', 'block_name/unit_2', etc. | |||
Second, this function allows the user to explicitly control the ResNet | |||
output_stride, which is the ratio of the input to output spatial resolution. | |||
This is useful for dense prediction tasks such as semantic segmentation or | |||
object detection. | |||
Most ResNets consist of 4 ResNet blocks and subsample the activations by a | |||
factor of 2 when transitioning between consecutive ResNet blocks. This results | |||
to a nominal ResNet output_stride equal to 8. If we set the output_stride to | |||
half the nominal network stride (e.g., output_stride=4), then we compute | |||
responses twice. | |||
Control of the output feature density is implemented by atrous convolution. | |||
Args: | |||
net: A `Tensor` of size [batch, height, width, channels]. | |||
blocks: A list of length equal to the number of ResNet `Blocks`. Each | |||
element is a ResNet `Block` object describing the units in the `Block`. | |||
output_stride: If `None`, then the output will be computed at the nominal | |||
network stride. If output_stride is not `None`, it specifies the requested | |||
ratio of input to output spatial resolution, which needs to be equal to | |||
the product of unit strides from the start up to some level of the ResNet. | |||
For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1, | |||
then valid values for the output_stride are 1, 2, 6, 24 or None (which | |||
is equivalent to output_stride=24). | |||
outputs_collections: Collection to add the ResNet block outputs. | |||
Returns: | |||
net: Output tensor with stride equal to the specified output_stride. | |||
Raises: | |||
ValueError: If the target output_stride is not valid. | |||
""" | |||
# The current_stride variable keeps track of the effective stride of the | |||
# activations. This allows us to invoke atrous convolution whenever applying | |||
# the next residual unit would result in the activations having stride larger | |||
# than the target output_stride. | |||
current_stride = 1 | |||
# The atrous convolution rate parameter. | |||
rate = 1 | |||
for block in blocks: | |||
with tf.variable_scope(block.scope, 'block', [net]): | |||
for i, unit in enumerate(block.args): | |||
if output_stride is not None and current_stride > output_stride: | |||
raise ValueError( | |||
'The target output_stride cannot be reached.') | |||
with tf.variable_scope( | |||
'unit_%d' % (i + 1), values=[net]) as sc: | |||
unit_depth, unit_depth_bottleneck, unit_stride = unit | |||
# If we have reached the target output_stride, then we need to employ | |||
# atrous convolution with stride=1 and multiply the atrous rate by the | |||
# current unit's stride for use in subsequent layers. | |||
if output_stride is not None and current_stride == output_stride: | |||
net = block.unit_fn( | |||
net, | |||
depth=unit_depth, | |||
depth_bottleneck=unit_depth_bottleneck, | |||
stride=1, | |||
rate=rate) | |||
rate *= unit_stride | |||
else: | |||
net = block.unit_fn( | |||
net, | |||
depth=unit_depth, | |||
depth_bottleneck=unit_depth_bottleneck, | |||
stride=unit_stride, | |||
rate=1) | |||
current_stride *= unit_stride | |||
net = slim.utils.collect_named_outputs( | |||
outputs_collections, sc.name, net) | |||
if output_stride is not None and current_stride != output_stride: | |||
raise ValueError('The target output_stride cannot be reached.') | |||
return net | |||
def resnet_arg_scope(weight_decay=0.0001, | |||
batch_norm_decay=0.997, | |||
batch_norm_epsilon=1e-5, | |||
batch_norm_scale=True): | |||
"""Defines the default ResNet arg scope. | |||
TODO(gpapan): The batch-normalization related default values above are | |||
appropriate for use in conjunction with the reference ResNet models | |||
released at https://github.com/KaimingHe/deep-residual-networks. When | |||
training ResNets from scratch, they might need to be tuned. | |||
Args: | |||
weight_decay: The weight decay to use for regularizing the model. | |||
batch_norm_decay: The moving average decay when estimating layer activation | |||
statistics in batch normalization. | |||
batch_norm_epsilon: Small constant to prevent division by zero when | |||
normalizing activations by their variance in batch normalization. | |||
batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the | |||
activations in the batch normalization layer. | |||
Returns: | |||
An `arg_scope` to use for the resnet models. | |||
""" | |||
batch_norm_params = { | |||
'decay': batch_norm_decay, | |||
'epsilon': batch_norm_epsilon, | |||
'scale': batch_norm_scale, | |||
'updates_collections': tf.GraphKeys.UPDATE_OPS, | |||
} | |||
with slim.arg_scope( | |||
[slim.conv2d], | |||
weights_regularizer=slim.l2_regularizer(weight_decay), | |||
weights_initializer=slim.variance_scaling_initializer(), | |||
activation_fn=tf.nn.relu, | |||
normalizer_fn=slim.batch_norm, | |||
normalizer_params=batch_norm_params): | |||
with slim.arg_scope([slim.batch_norm], **batch_norm_params): | |||
# The original slim arg scope uses padding='SAME' for pool1, which makes
# feature alignment easier for dense prediction tasks (as in
# https://github.com/facebook/fb.resnet.torch). Here padding='VALID' is used
# instead, matching the accompanying code of 'Deep Residual Learning for
# Image Recognition'; switch back by setting
# slim.arg_scope([slim.max_pool2d], padding='SAME').
with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: | |||
return arg_sc |
@@ -0,0 +1,108 @@ | |||
import cv2 | |||
import numpy as np | |||
def rboxes_to_polygons(rboxes):
"""Convert rotated boxes to quadrilateral polygons.
Args:
rboxes: [n, 5] array of [cx, cy, w, h, theta].
Returns:
polygons: [n, 8] array of corner coordinates [x1, y1, x2, y2, x3, y3, x4, y4].
"""
theta = rboxes[:, 4:5] | |||
cxcy = rboxes[:, :2] | |||
half_w = rboxes[:, 2:3] / 2. | |||
half_h = rboxes[:, 3:4] / 2. | |||
v1 = np.hstack([np.cos(theta) * half_w, np.sin(theta) * half_w]) | |||
v2 = np.hstack([-np.sin(theta) * half_h, np.cos(theta) * half_h]) | |||
p1 = cxcy - v1 - v2 | |||
p2 = cxcy + v1 - v2 | |||
p3 = cxcy + v1 + v2 | |||
p4 = cxcy - v1 + v2 | |||
polygons = np.hstack([p1, p2, p3, p4]) | |||
return polygons | |||
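# Illustrative usage (not part of the patch): an axis-aligned rbox (theta=0)
# centred at (10, 20) with w=8, h=4 yields corners starting at
# (cx - w/2, cy - h/2), in the [x1, y1, ..., x4, y4] layout consumed by
# OCRDetectionPipeline.postprocess().
_demo_rboxes = np.array([[10., 20., 8., 4., 0.]])
print(rboxes_to_polygons(_demo_rboxes))  # -> [[ 6. 18. 14. 18. 14. 22. 6. 22.]]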
def cal_width(box): | |||
pd1 = point_dist(box[0], box[1], box[2], box[3]) | |||
pd2 = point_dist(box[4], box[5], box[6], box[7]) | |||
return (pd1 + pd2) / 2 | |||
def point_dist(x1, y1, x2, y2): | |||
return np.sqrt((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1)) | |||
def draw_polygons(img, polygons): | |||
for p in polygons.tolist(): | |||
p = [int(o) for o in p] | |||
cv2.line(img, (p[0], p[1]), (p[2], p[3]), (0, 255, 0), 1) | |||
cv2.line(img, (p[2], p[3]), (p[4], p[5]), (0, 255, 0), 1) | |||
cv2.line(img, (p[4], p[5]), (p[6], p[7]), (0, 255, 0), 1) | |||
cv2.line(img, (p[6], p[7]), (p[0], p[1]), (0, 255, 0), 1) | |||
return img | |||
def nms_python(boxes):
"""Pure-Python NMS over 9-element boxes [x1, y1, ..., x4, y4, score]: boxes are
sorted by score (descending) and a lower-scoring box is suppressed when its
rotated-box centre lies inside a higher-scoring box, or vice versa."""
boxes = sorted(boxes, key=lambda x: -x[8]) | |||
nms_flag = [True] * len(boxes) | |||
for i, a in enumerate(boxes): | |||
if not nms_flag[i]: | |||
continue | |||
else: | |||
for j, b in enumerate(boxes): | |||
if not j > i: | |||
continue | |||
if not nms_flag[j]: | |||
continue | |||
score_a = a[8] | |||
score_b = b[8] | |||
rbox_a = polygon2rbox(a[:8]) | |||
rbox_b = polygon2rbox(b[:8]) | |||
if point_in_rbox(rbox_a[:2], rbox_b) or point_in_rbox( | |||
rbox_b[:2], rbox_a): | |||
if score_a > score_b: | |||
nms_flag[j] = False | |||
boxes_nms = [] | |||
for i, box in enumerate(boxes): | |||
if nms_flag[i]: | |||
boxes_nms.append(box) | |||
return boxes_nms | |||
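# Illustrative usage (not part of the patch), mirroring postprocess(): each box
# is a polygon [x1, ..., y4] extended with its width (cal_width) as the score;
# the narrower of two mutually overlapping boxes is suppressed.
_demo_polys = [
    [0, 0, 100, 0, 100, 20, 0, 20],  # wide text line
    [2, 2, 50, 2, 50, 18, 2, 18],    # overlapping, narrower detection
]
_demo_boxes = [p + [cal_width(p)] for p in _demo_polys]
print([b[:8] for b in nms_python(_demo_boxes)])  # keeps only the wide line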
def point_in_rbox(c, rbox): | |||
cx0, cy0 = c[0], c[1] | |||
cx1, cy1 = rbox[0], rbox[1] | |||
w, h = rbox[2], rbox[3] | |||
theta = rbox[4] | |||
dist_x = np.abs((cx1 - cx0) * np.cos(theta) + (cy1 - cy0) * np.sin(theta)) | |||
dist_y = np.abs(-(cx1 - cx0) * np.sin(theta) + (cy1 - cy0) * np.cos(theta)) | |||
return ((dist_x < w / 2.0) and (dist_y < h / 2.0)) | |||
def polygon2rbox(polygon):
"""Fit a rotated box [cx, cy, w, h, theta] to a quadrilateral [x1, y1, ..., x4, y4]."""
x1, x2, x3, x4 = polygon[0], polygon[2], polygon[4], polygon[6] | |||
y1, y2, y3, y4 = polygon[1], polygon[3], polygon[5], polygon[7] | |||
c_x = (x1 + x2 + x3 + x4) / 4 | |||
c_y = (y1 + y2 + y3 + y4) / 4 | |||
w1 = point_dist(x1, y1, x2, y2) | |||
w2 = point_dist(x3, y3, x4, y4) | |||
h1 = point_line_dist(c_x, c_y, x1, y1, x2, y2) | |||
h2 = point_line_dist(c_x, c_y, x3, y3, x4, y4) | |||
h = h1 + h2 | |||
w = (w1 + w2) / 2 | |||
theta1 = np.arctan2(y2 - y1, x2 - x1) | |||
theta2 = np.arctan2(y3 - y4, x3 - x4) | |||
theta = (theta1 + theta2) / 2.0 | |||
return [c_x, c_y, w, h, theta] | |||
def point_line_dist(px, py, x1, y1, x2, y2): | |||
eps = 1e-6 | |||
dx = x2 - x1 | |||
dy = y2 - y1 | |||
div = np.sqrt(dx * dx + dy * dy) + eps | |||
dist = np.abs(px * dy - py * dx + x2 * y1 - y2 * x1) / div | |||
return dist |
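A quick numeric check (not part of the patch) of point_line_dist defined above:

# the distance from the point (0, 5) to the line through (0, 0) and (10, 0)
# should be 5 (up to the eps in the denominator).
print(point_line_dist(0, 5, 0, 0, 10, 0))  # ~5.0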
@@ -54,6 +54,13 @@ TASK_OUTPUTS = { | |||
# } | |||
Tasks.pose_estimation: ['poses', 'boxes'], | |||
# ocr detection result for single sample | |||
# { | |||
# "det_polygons": np.array with shape [num_text, 8], each box is | |||
# [x1, y1, x2, y2, x3, y3, x4, y4] | |||
# } | |||
Tasks.ocr_detection: ['det_polygons'], | |||
# ============ nlp tasks =================== | |||
# text classification result for single sample | |||
@@ -28,6 +28,7 @@ class Tasks(object): | |||
image_editing = 'image-editing' | |||
image_generation = 'image-generation' | |||
image_matting = 'image-matting' | |||
ocr_detection = 'ocr-detection' | |||
# nlp tasks | |||
word_segmentation = 'word-segmentation' | |||
@@ -1 +1,2 @@ | |||
easydict | |||
tf_slim |
@@ -0,0 +1,37 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os.path as osp | |||
import shutil | |||
import sys | |||
import tempfile | |||
import unittest | |||
from typing import Any, Dict, List, Tuple, Union | |||
import cv2 | |||
import numpy as np | |||
import PIL | |||
from modelscope.pipelines import pipeline | |||
from modelscope.pipelines.base import Pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.test_utils import test_level | |||
class OCRDetectionTest(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.model_id = 'damo/cv_resnet18_ocr-detection-line-level_damo' | |||
self.test_image = 'data/test/images/ocr_detection.jpg' | |||
def pipeline_inference(self, pipeline: Pipeline, input_location: str): | |||
result = pipeline(input_location) | |||
print('ocr detection results: ') | |||
print(result) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_run_modelhub_default_model(self): | |||
ocr_detection = pipeline(Tasks.ocr_detection) | |||
self.pipeline_inference(ocr_detection, self.test_image) | |||
if __name__ == '__main__': | |||
unittest.main() |
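A hedged sketch (not part of the patch): self.model_id is defined but unused in the test above; assuming the pipeline factory also accepts an explicit model id, a direct invocation could look like this.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

ocr_detection = pipeline(
    Tasks.ocr_detection, model='damo/cv_resnet18_ocr-detection-line-level_damo')
result = ocr_detection('data/test/images/ocr_detection.jpg')
print(result['det_polygons'])  # np.array with shape [num_text, 8]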