
[to #42322933] create ocr_detection task

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8967432

* create ocr_detection task

* replace c++ nms with python version

* replace c++ decoder with python version

* add requirements for ocr_detection
Branch: master
xixing.tj, wenmeng.zwm · 3 years ago
commit 968e660235
14 changed files with 2246 additions and 0 deletions
  1. data/test/images/ocr_detection.jpg (+3, -0)
  2. modelscope/pipelines/builder.py (+2, -0)
  3. modelscope/pipelines/cv/__init__.py (+1, -0)
  4. modelscope/pipelines/cv/ocr_detection_pipeline.py (+167, -0)
  5. modelscope/pipelines/cv/ocr_utils/__init__.py (+0, -0)
  6. modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py (+158, -0)
  7. modelscope/pipelines/cv/ocr_utils/ops.py (+1098, -0)
  8. modelscope/pipelines/cv/ocr_utils/resnet18_v1.py (+432, -0)
  9. modelscope/pipelines/cv/ocr_utils/resnet_utils.py (+231, -0)
  10. modelscope/pipelines/cv/ocr_utils/utils.py (+108, -0)
  11. modelscope/pipelines/outputs.py (+7, -0)
  12. modelscope/utils/constant.py (+1, -0)
  13. requirements/cv.txt (+1, -0)
  14. tests/pipelines/test_ocr_detection.py (+37, -0)

data/test/images/ocr_detection.jpg (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5c8435db5583400be5d11a2c17910c96133b462c8a99ccaf0e19f4aac34e0a94
size 141149

modelscope/pipelines/builder.py (+2, -0)

@@ -28,6 +28,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.image_generation:
    ('person-image-cartoon',
     'damo/cv_unet_person-image-cartoon_compound-models'),
    Tasks.ocr_detection: ('ocr-detection',
                          'damo/cv_resnet18_ocr-detection-line-level_damo'),
}
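
With this entry in place, the task can be constructed without an explicit model id. A minimal usage sketch (the default-model call mirrors the new test added in this change; passing `model=` explicitly is assumed to work the same way through the `pipeline()` factory):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# No model argument: resolves to
# 'damo/cv_resnet18_ocr-detection-line-level_damo' via DEFAULT_MODEL_FOR_PIPELINE.
ocr_detection = pipeline(Tasks.ocr_detection)

# An explicit model id can also be given (assumed keyword form).
ocr_detection = pipeline(
    Tasks.ocr_detection,
    model='damo/cv_resnet18_ocr-detection-line-level_damo')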




modelscope/pipelines/cv/__init__.py (+1, -0)

@@ -1,2 +1,3 @@
from .image_cartoon_pipeline import ImageCartoonPipeline
from .image_matting_pipeline import ImageMattingPipeline
from .ocr_detection_pipeline import OCRDetectionPipeline

modelscope/pipelines/cv/ocr_detection_pipeline.py (+167, -0)

@@ -0,0 +1,167 @@
import math
import os
import os.path as osp
import sys
from typing import Any, Dict, List, Tuple, Union

import cv2
import numpy as np
import PIL
import tensorflow as tf
import tf_slim as slim

from modelscope.pipelines.base import Input
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES
from .ocr_utils import model_resnet_mutex_v4_linewithchar, ops, utils

if tf.__version__ >= '2.0':
    tf = tf.compat.v1
    tf.compat.v1.disable_eager_execution()

logger = get_logger()

# constants
RBOX_DIM = 5
OFFSET_DIM = 6
WORD_POLYGON_DIM = 8
OFFSET_VARIANCE = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1]

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_float('node_threshold', 0.4,
                          'Confidence threshold for nodes')
tf.app.flags.DEFINE_float('link_threshold', 0.6,
                          'Confidence threshold for links')


@PIPELINES.register_module(
    Tasks.ocr_detection, module_name=Tasks.ocr_detection)
class OCRDetectionPipeline(Pipeline):

    def __init__(self, model: str):
        super().__init__(model=model)
        model_path = osp.join(
            osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER),
            'checkpoint-80000')

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self._session = tf.Session(config=config)
        global_step = tf.get_variable(
            'global_step', [],
            initializer=tf.constant_initializer(0),
            dtype=tf.int64,
            trainable=False)
        variable_averages = tf.train.ExponentialMovingAverage(
            0.997, global_step)
        self.input_images = tf.placeholder(
            tf.float32, shape=[1, 1024, 1024, 3], name='input_images')
        self.output = {}

        # detector
        detector = model_resnet_mutex_v4_linewithchar.SegLinkDetector()
        all_maps = detector.build_model(self.input_images, is_training=False)

        # decode local predictions
        all_nodes, all_links, all_reg = [], [], []
        for i, maps in enumerate(all_maps):
            cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2]
            reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE)

            cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))

            lnk_prob_pos = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 4])[:, :2])
            lnk_prob_mut = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 4])[:, 2:])
            lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1)

            all_nodes.append(cls_prob)
            all_links.append(lnk_prob)
            all_reg.append(reg_maps)

        # decode segments and links
        image_size = tf.shape(self.input_images)[1:3]
        segments, group_indices, segment_counts, _ = ops.decode_segments_links_python(
            image_size,
            all_nodes,
            all_links,
            all_reg,
            anchor_sizes=list(detector.anchor_sizes))

        # combine segments
        combined_rboxes, combined_counts = ops.combine_segments_python(
            segments, group_indices, segment_counts)
        self.output['combined_rboxes'] = combined_rboxes
        self.output['combined_counts'] = combined_counts

        with self._session.as_default() as sess:
            logger.info(f'loading model from {model_path}')
            # load model
            model_loader = tf.train.Saver(
                variable_averages.variables_to_restore())
            model_loader.restore(sess, model_path)

    def preprocess(self, input: Input) -> Dict[str, Any]:
        if isinstance(input, str):
            img = np.array(load_image(input))
        elif isinstance(input, PIL.Image.Image):
            img = np.array(input.convert('RGB'))
        elif isinstance(input, np.ndarray):
            if len(input.shape) == 2:
                input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
            img = input[:, :, ::-1]  # in rgb order
        else:
            raise TypeError(f'input should be either str, PIL.Image,'
                            f' np.array, but got {type(input)}')
        h, w, c = img.shape
        img_pad = np.zeros((max(h, w), max(h, w), 3), dtype=np.float32)
        img_pad[:h, :w, :] = img

        resize_size = 1024
        img_pad_resize = cv2.resize(img_pad, (resize_size, resize_size))
        img_pad_resize = cv2.cvtColor(img_pad_resize, cv2.COLOR_RGB2BGR)
        img_pad_resize = img_pad_resize - np.array([123.68, 116.78, 103.94],
                                                   dtype=np.float32)

        resize_size = tf.stack([resize_size, resize_size])
        orig_size = tf.stack([max(h, w), max(h, w)])
        self.output['orig_size'] = orig_size
        self.output['resize_size'] = resize_size

        result = {'img': np.expand_dims(img_pad_resize, axis=0)}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        with self._session.as_default():
            feed_dict = {self.input_images: input['img']}
            sess_outputs = self._session.run(self.output, feed_dict=feed_dict)
            return sess_outputs

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        rboxes = inputs['combined_rboxes'][0]
        count = inputs['combined_counts'][0]
        rboxes = rboxes[:count, :]

        # convert rboxes to polygons and find their coordinates on the original image
        orig_h, orig_w = inputs['orig_size']
        resize_h, resize_w = inputs['resize_size']
        polygons = utils.rboxes_to_polygons(rboxes)
        scale_y = float(orig_h) / float(resize_h)
        scale_x = float(orig_w) / float(resize_w)

        # confine polygons inside image
        polygons[:, ::2] = np.maximum(
            0, np.minimum(polygons[:, ::2] * scale_x, orig_w - 1))
        polygons[:, 1::2] = np.maximum(
            0, np.minimum(polygons[:, 1::2] * scale_y, orig_h - 1))
        polygons = np.round(polygons).astype(np.int32)

        # nms
        dt_n9 = [o + [utils.cal_width(o)] for o in polygons.tolist()]
        dt_nms = utils.nms_python(dt_n9)
        dt_polygons = np.array([o[:8] for o in dt_nms])

        result = {'det_polygons': dt_polygons}
        return result
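
Per the commit message, the former C++ NMS and decoder are replaced by the pure-Python helpers in ocr_utils (utils.nms_python, ops.decode_segments_links_python, ops.combine_segments_python). A toy sketch of the polygon-and-NMS part of postprocess() on made-up rboxes:

import numpy as np

from modelscope.pipelines.cv.ocr_utils import utils

# two heavily overlapping rotated boxes [cx, cy, w, h, theta]; the second is a near-duplicate
rboxes = np.array([[100., 50., 80., 20., 0.0],
                   [101., 51., 78., 20., 0.02]])

polygons = utils.rboxes_to_polygons(rboxes)            # [n, 8] corner coordinates
dt_n9 = [p + [utils.cal_width(p)] for p in polygons.tolist()]
kept = utils.nms_python(dt_n9)                         # suppresses the narrower duplicate
det_polygons = np.array([p[:8] for p in kept])
print(det_polygons.shape)                              # (1, 8) expected for this toy case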

modelscope/pipelines/cv/ocr_utils/__init__.py (+0, -0)


modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py (+158, -0)

@@ -0,0 +1,158 @@
import tensorflow as tf
import tf_slim as slim

from . import ops, resnet18_v1, resnet_utils

if tf.__version__ >= '2.0':
    tf = tf.compat.v1

# constants
OFFSET_DIM = 6

N_LOCAL_LINKS = 8
N_CROSS_LINKS = 4
N_SEG_CLASSES = 2
N_LNK_CLASSES = 4

POS_LABEL = 1
NEG_LABEL = 0


class SegLinkDetector():

    def __init__(self):
        self.anchor_sizes = [6., 11.84210526, 23.68421053, 45., 90., 150.]

    def _detection_classifier(self,
                              maps,
                              ksize,
                              weight_decay,
                              cross_links=False,
                              scope=None):

        with tf.variable_scope(scope):
            seg_depth = N_SEG_CLASSES
            if cross_links:
                lnk_depth = N_LNK_CLASSES * (N_LOCAL_LINKS + N_CROSS_LINKS)
            else:
                lnk_depth = N_LNK_CLASSES * N_LOCAL_LINKS
            reg_depth = OFFSET_DIM
            map_depth = maps.get_shape()[3]
            inter_maps, inter_relu = ops.conv2d(
                maps, map_depth, 256, 1, 1, 'SAME', scope='conv_inter')

            dir_maps, dir_relu = ops.conv2d(
                inter_relu, 256, 2, ksize, 1, 'SAME', scope='conv_dir')
            cen_maps, cen_relu = ops.conv2d(
                inter_relu, 256, 2, ksize, 1, 'SAME', scope='conv_cen')
            pol_maps, pol_relu = ops.conv2d(
                inter_relu, 256, 8, ksize, 1, 'SAME', scope='conv_pol')
            concat_relu = tf.concat([dir_relu, cen_relu, pol_relu], axis=-1)
            _, lnk_embedding = ops.conv_relu(
                concat_relu, 12, 256, 1, 1, scope='lnk_embedding')
            lnk_maps, lnk_relu = ops.conv2d(
                inter_relu + lnk_embedding,
                256,
                lnk_depth,
                ksize,
                1,
                'SAME',
                scope='conv_lnk')

            char_seg_maps, char_seg_relu = ops.conv2d(
                inter_relu,
                256,
                seg_depth,
                ksize,
                1,
                'SAME',
                scope='conv_char_cls')
            char_reg_maps, char_reg_relu = ops.conv2d(
                inter_relu,
                256,
                reg_depth,
                ksize,
                1,
                'SAME',
                scope='conv_char_reg')
            concat_char_relu = tf.concat([char_seg_relu, char_reg_relu],
                                         axis=-1)
            _, char_embedding = ops.conv_relu(
                concat_char_relu, 8, 256, 1, 1, scope='conv_char_embedding')
            seg_maps, seg_relu = ops.conv2d(
                inter_relu + char_embedding,
                256,
                seg_depth,
                ksize,
                1,
                'SAME',
                scope='conv_cls')
            reg_maps, reg_relu = ops.conv2d(
                inter_relu + char_embedding,
                256,
                reg_depth,
                ksize,
                1,
                'SAME',
                scope='conv_reg')

            return seg_relu, lnk_relu, reg_relu

    def _build_cnn(self, images, weight_decay, is_training):
        with slim.arg_scope(
                resnet18_v1.resnet_arg_scope(weight_decay=weight_decay)):
            logits, end_points = resnet18_v1.resnet_v1_18(
                images, is_training=is_training, scope='resnet_v1_18')

        outputs = {
            'conv3_3': end_points['pool1'],
            'conv4_3': end_points['pool2'],
            'fc7': end_points['pool3'],
            'conv8_2': end_points['pool4'],
            'conv9_2': end_points['pool5'],
            'conv10_2': end_points['pool6'],
        }
        return outputs

    def build_model(self, images, is_training=True, scope=None):

        weight_decay = 5e-4  # FLAGS.weight_decay
        cnn_outputs = self._build_cnn(images, weight_decay, is_training)
        det_0 = self._detection_classifier(
            cnn_outputs['conv3_3'],
            3,
            weight_decay,
            cross_links=False,
            scope='dete_0')
        det_1 = self._detection_classifier(
            cnn_outputs['conv4_3'],
            3,
            weight_decay,
            cross_links=True,
            scope='dete_1')
        det_2 = self._detection_classifier(
            cnn_outputs['fc7'],
            3,
            weight_decay,
            cross_links=True,
            scope='dete_2')
        det_3 = self._detection_classifier(
            cnn_outputs['conv8_2'],
            3,
            weight_decay,
            cross_links=True,
            scope='dete_3')
        det_4 = self._detection_classifier(
            cnn_outputs['conv9_2'],
            3,
            weight_decay,
            cross_links=True,
            scope='dete_4')
        det_5 = self._detection_classifier(
            cnn_outputs['conv10_2'],
            3,
            weight_decay,
            cross_links=True,
            scope='dete_5')
        outputs = [det_0, det_1, det_2, det_3, det_4, det_5]
        return outputs
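
The pipeline feeds build_model() a fixed 1x1024x1024x3 placeholder and decodes the six returned (cls, lnk, reg) tuples. A small graph-mode sketch for inspecting those per-level output shapes (TF1-style API, as used throughout this commit):

import tensorflow as tf

from modelscope.pipelines.cv.ocr_utils import model_resnet_mutex_v4_linewithchar

if tf.__version__ >= '2.0':
    tf = tf.compat.v1
    tf.compat.v1.disable_eager_execution()

images = tf.placeholder(tf.float32, shape=[1, 1024, 1024, 3], name='images')
detector = model_resnet_mutex_v4_linewithchar.SegLinkDetector()
all_maps = detector.build_model(images, is_training=False)

# one (cls_maps, lnk_maps, reg_maps) tuple per feature level dete_0 .. dete_5
for level, (cls_maps, lnk_maps, reg_maps) in enumerate(all_maps):
    print(level, cls_maps.shape, lnk_maps.shape, reg_maps.shape)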

modelscope/pipelines/cv/ocr_utils/ops.py (+1098, -0) — file diff suppressed because it is too large


modelscope/pipelines/cv/ocr_utils/resnet18_v1.py (+432, -0)

@@ -0,0 +1,432 @@
"""Contains definitions for the original form of Residual Networks.
The 'v1' residual networks (ResNets) implemented in this module were proposed
by:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Deep Residual Learning for Image Recognition. arXiv:1512.03385
Other variants were introduced in:
[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Identity Mappings in Deep Residual Networks. arXiv: 1603.05027
The networks defined in this module utilize the bottleneck building block of
[1] with projection shortcuts only for increasing depths. They employ batch
normalization *after* every weight layer. This is the architecture used by
MSRA in the Imagenet and MSCOCO 2016 competition models ResNet-101 and
ResNet-152. See [2; Fig. 1a] for a comparison between the current 'v1'
architecture and the alternative 'v2' architecture of [2] which uses batch
normalization *before* every weight layer in the so-called full pre-activation
units.
Typical use:
from tensorflow.contrib.slim.nets import resnet_v1
ResNet-101 for image classification into 1000 classes:
# inputs has shape [batch, 224, 224, 3]
with slim.arg_scope(resnet_v1.resnet_arg_scope()):
net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False)
ResNet-101 for semantic segmentation into 21 classes:
# inputs has shape [batch, 513, 513, 3]
with slim.arg_scope(resnet_v1.resnet_arg_scope()):
net, end_points = resnet_v1.resnet_v1_101(inputs,
21,
is_training=False,
global_pool=False,
output_stride=16)
"""
import tensorflow as tf
import tf_slim as slim

from . import resnet_utils

if tf.__version__ >= '2.0':
    tf = tf.compat.v1

resnet_arg_scope = resnet_utils.resnet_arg_scope


@slim.add_arg_scope
def basicblock(inputs,
depth,
depth_bottleneck,
stride,
rate=1,
outputs_collections=None,
scope=None):
"""Bottleneck residual unit variant with BN after convolutions.
This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for
its definition. Note that we use here the bottleneck variant which has an
extra bottleneck layer.
When putting together two consecutive ResNet blocks that use this unit, one
should use stride = 2 in the last unit of the first block.
Args:
inputs: A tensor of size [batch, height, width, channels].
depth: The depth of the ResNet unit output.
depth_bottleneck: The depth of the bottleneck layers (unused by this basic unit).
stride: The ResNet unit's stride. Determines the amount of downsampling of
the units output compared to its input.
rate: An integer, rate for atrous convolution.
outputs_collections: Collection to add the ResNet unit output.
scope: Optional variable_scope.
Returns:
The ResNet unit's output.
"""
with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc:
depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
if depth == depth_in:
shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
else:
shortcut = slim.conv2d(
inputs,
depth, [1, 1],
stride=stride,
activation_fn=None,
scope='shortcut')

residual = resnet_utils.conv2d_same(
inputs, depth, 3, stride, rate=rate, scope='conv1')
residual = resnet_utils.conv2d_same(
residual, depth, 3, 1, rate=rate, scope='conv2')

output = tf.nn.relu(residual + shortcut)

return slim.utils.collect_named_outputs(outputs_collections,
sc.original_name_scope, output)


@slim.add_arg_scope
def bottleneck(inputs,
depth,
depth_bottleneck,
stride,
rate=1,
outputs_collections=None,
scope=None):
"""Bottleneck residual unit variant with BN after convolutions.
This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for
its definition. Note that we use here the bottleneck variant which has an
extra bottleneck layer.
When putting together two consecutive ResNet blocks that use this unit, one
should use stride = 2 in the last unit of the first block.
Args:
inputs: A tensor of size [batch, height, width, channels].
depth: The depth of the ResNet unit output.
depth_bottleneck: The depth of the bottleneck layers.
stride: The ResNet unit's stride. Determines the amount of downsampling of
the units output compared to its input.
rate: An integer, rate for atrous convolution.
outputs_collections: Collection to add the ResNet unit output.
scope: Optional variable_scope.
Returns:
The ResNet unit's output.
"""
with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc:
depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
if depth == depth_in:
shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
else:
shortcut = slim.conv2d(
inputs,
depth, [1, 1],
stride=stride,
activation_fn=None,
scope='shortcut')

residual = slim.conv2d(
inputs, depth_bottleneck, [1, 1], stride=1, scope='conv1')
residual = resnet_utils.conv2d_same(
residual, depth_bottleneck, 3, stride, rate=rate, scope='conv2')
residual = slim.conv2d(
residual,
depth, [1, 1],
stride=1,
activation_fn=None,
scope='conv3')

output = tf.nn.relu(shortcut + residual)

return slim.utils.collect_named_outputs(outputs_collections,
sc.original_name_scope, output)


def resnet_v1(inputs,
blocks,
num_classes=None,
is_training=True,
global_pool=True,
output_stride=None,
include_root_block=True,
spatial_squeeze=True,
reuse=None,
scope=None):
"""Generator for v1 ResNet models.
This function generates a family of ResNet v1 models. See the resnet_v1_*()
methods for specific model instantiations, obtained by selecting different
block instantiations that produce ResNets of various depths.
Training for image classification on Imagenet is usually done with [224, 224]
inputs, resulting in [7, 7] feature maps at the output of the last ResNet
block for the ResNets defined in [1] that have nominal stride equal to 32.
However, for dense prediction tasks we advise that one uses inputs with
spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In
this case the feature maps at the ResNet output will have spatial shape
[(height - 1) / output_stride + 1, (width - 1) / output_stride + 1]
and corners exactly aligned with the input image corners, which greatly
facilitates alignment of the features to the image. Using as input [225, 225]
images results in [8, 8] feature maps at the output of the last ResNet block.
For dense prediction tasks, the ResNet needs to run in fully-convolutional
(FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all
have nominal stride equal to 32 and a good choice in FCN mode is to use
output_stride=16 in order to increase the density of the computed features at
small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915.
Args:
inputs: A tensor of size [batch, height_in, width_in, channels].
blocks: A list of length equal to the number of ResNet blocks. Each element
is a resnet_utils.Block object describing the units in the block.
num_classes: Number of predicted classes for classification tasks. If None
we return the features before the logit layer.
is_training: whether batch_norm layers are in training mode.
global_pool: If True, we perform global average pooling before computing the
logits. Set to True for image classification, False for dense prediction.
output_stride: If None, then the output will be computed at the nominal
network stride. If output_stride is not None, it specifies the requested
ratio of input to output spatial resolution.
include_root_block: If True, include the initial convolution followed by
max-pooling, if False excludes it.
spatial_squeeze: if True, logits is of shape [B, C], if false logits is
of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given.
scope: Optional variable_scope.
Returns:
net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
If global_pool is False, then height_out and width_out are reduced by a
factor of output_stride compared to the respective height_in and width_in,
else both height_out and width_out equal one. If num_classes is None, then
net is the output of the last ResNet block, potentially after global
average pooling. If num_classes is not None, net contains the pre-softmax
activations.
end_points: A dictionary from components of the network to the corresponding
activation.
Raises:
ValueError: If the target output_stride is not valid.
"""
with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc:
end_points_collection = sc.name + '_end_points'
with slim.arg_scope(
[slim.conv2d, bottleneck, resnet_utils.stack_blocks_dense],
outputs_collections=end_points_collection):
with slim.arg_scope([slim.batch_norm], is_training=is_training):
net = inputs
if include_root_block:
if output_stride is not None:
if output_stride % 4 != 0:
raise ValueError(
'The output_stride needs to be a multiple of 4.'
)
output_stride /= 4
net = resnet_utils.conv2d_same(
net, 64, 7, stride=2, scope='conv1')
net = tf.pad(net, [[0, 0], [1, 1], [1, 1], [0, 0]])
net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')

net = slim.utils.collect_named_outputs(
end_points_collection, 'pool2', net)

net = resnet_utils.stack_blocks_dense(net, blocks,
output_stride)

end_points = slim.utils.convert_collection_to_dict(
end_points_collection)

end_points['pool1'] = end_points['resnet_v1_18/block2/unit_2']
end_points['pool2'] = end_points['resnet_v1_18/block3/unit_2']
end_points['pool3'] = end_points['resnet_v1_18/block4/unit_2']
end_points['pool4'] = end_points['resnet_v1_18/block5/unit_2']
end_points['pool5'] = end_points['resnet_v1_18/block6/unit_2']
end_points['pool6'] = net

return net, end_points


resnet_v1.default_image_size = 224


def resnet_v1_18(inputs,
num_classes=None,
is_training=True,
global_pool=True,
output_stride=None,
spatial_squeeze=True,
reuse=None,
scope='resnet_v1_18'):
"""ResNet-18 model of [1]. See resnet_v1() for arg and return description."""
blocks = [
resnet_utils.Block('block1', basicblock,
[(64, 64, 1)] + [(64, 64, 1)]),
resnet_utils.Block('block2', basicblock,
[(128, 128, 1)] + [(128, 128, 1)]),
resnet_utils.Block('block3', basicblock,
[(256, 256, 2)] + [(256, 256, 1)]),
resnet_utils.Block('block4', basicblock,
[(512, 512, 2)] + [(512, 512, 1)]),
resnet_utils.Block('block5', basicblock,
[(256, 256, 2)] + [(256, 256, 1)]),
resnet_utils.Block('block6', basicblock,
[(256, 256, 2)] + [(256, 256, 1)]),
resnet_utils.Block('block7', basicblock,
[(256, 256, 2)] + [(256, 256, 1)]),
]
return resnet_v1(
inputs,
blocks,
num_classes,
is_training,
global_pool=global_pool,
output_stride=output_stride,
include_root_block=True,
spatial_squeeze=spatial_squeeze,
reuse=reuse,
scope=scope)


resnet_v1_18.default_image_size = resnet_v1.default_image_size


def resnet_v1_50(inputs,
num_classes=None,
is_training=True,
global_pool=True,
output_stride=None,
spatial_squeeze=True,
reuse=None,
scope='resnet_v1_50'):
"""ResNet-50 model of [1]. See resnet_v1() for arg and return description."""
blocks = [
resnet_utils.Block('block1', bottleneck,
[(256, 64, 1)] * 2 + [(256, 64, 2)]),
resnet_utils.Block('block2', bottleneck,
[(512, 128, 1)] * 3 + [(512, 128, 2)]),
resnet_utils.Block('block3', bottleneck,
[(1024, 256, 1)] * 5 + [(1024, 256, 2)]),
resnet_utils.Block('block4', bottleneck,
[(2048, 512, 1)] * 3 + [(2048, 512, 2)]),
resnet_utils.Block('block5', bottleneck,
[(1024, 256, 1)] * 2 + [(1024, 256, 2)]),
resnet_utils.Block('block6', bottleneck, [(1024, 256, 1)] * 2),
]
return resnet_v1(
inputs,
blocks,
num_classes,
is_training,
global_pool=global_pool,
output_stride=output_stride,
include_root_block=True,
spatial_squeeze=spatial_squeeze,
reuse=reuse,
scope=scope)


resnet_v1_50.default_image_size = resnet_v1.default_image_size


def resnet_v1_101(inputs,
num_classes=None,
is_training=True,
global_pool=True,
output_stride=None,
spatial_squeeze=True,
reuse=None,
scope='resnet_v1_101'):
"""ResNet-101 model of [1]. See resnet_v1() for arg and return description."""
blocks = [
resnet_utils.Block('block1', bottleneck,
[(256, 64, 1)] * 2 + [(256, 64, 2)]),
resnet_utils.Block('block2', bottleneck,
[(512, 128, 1)] * 3 + [(512, 128, 2)]),
resnet_utils.Block('block3', bottleneck,
[(1024, 256, 1)] * 22 + [(1024, 256, 2)]),
resnet_utils.Block('block4', bottleneck, [(2048, 512, 1)] * 3)
]
return resnet_v1(
inputs,
blocks,
num_classes,
is_training,
global_pool=global_pool,
output_stride=output_stride,
include_root_block=True,
spatial_squeeze=spatial_squeeze,
reuse=reuse,
scope=scope)


resnet_v1_101.default_image_size = resnet_v1.default_image_size


def resnet_v1_152(inputs,
num_classes=None,
is_training=True,
global_pool=True,
output_stride=None,
spatial_squeeze=True,
reuse=None,
scope='resnet_v1_152'):
"""ResNet-152 model of [1]. See resnet_v1() for arg and return description."""
blocks = [
resnet_utils.Block('block1', bottleneck,
[(256, 64, 1)] * 2 + [(256, 64, 2)]),
resnet_utils.Block('block2', bottleneck,
[(512, 128, 1)] * 7 + [(512, 128, 2)]),
resnet_utils.Block('block3', bottleneck,
[(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
resnet_utils.Block('block4', bottleneck, [(2048, 512, 1)] * 3)
]
return resnet_v1(
inputs,
blocks,
num_classes,
is_training,
global_pool=global_pool,
output_stride=output_stride,
include_root_block=True,
spatial_squeeze=spatial_squeeze,
reuse=reuse,
scope=scope)


resnet_v1_152.default_image_size = resnet_v1.default_image_size


def resnet_v1_200(inputs,
num_classes=None,
is_training=True,
global_pool=True,
output_stride=None,
spatial_squeeze=True,
reuse=None,
scope='resnet_v1_200'):
"""ResNet-200 model of [2]. See resnet_v1() for arg and return description."""
blocks = [
resnet_utils.Block('block1', bottleneck,
[(256, 64, 1)] * 2 + [(256, 64, 2)]),
resnet_utils.Block('block2', bottleneck,
[(512, 128, 1)] * 23 + [(512, 128, 2)]),
resnet_utils.Block('block3', bottleneck,
[(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
resnet_utils.Block('block4', bottleneck, [(2048, 512, 1)] * 3)
]
return resnet_v1(
inputs,
blocks,
num_classes,
is_training,
global_pool=global_pool,
output_stride=output_stride,
include_root_block=True,
spatial_squeeze=spatial_squeeze,
reuse=reuse,
scope=scope)


resnet_v1_200.default_image_size = resnet_v1.default_image_size

if __name__ == '__main__':
    input = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='input')
    with slim.arg_scope(resnet_arg_scope()) as sc:
        logits = resnet_v1_50(input)

modelscope/pipelines/cv/ocr_utils/resnet_utils.py (+231, -0)

@@ -0,0 +1,231 @@
"""Contains building blocks for various versions of Residual Networks.
Residual networks (ResNets) were proposed in:
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015
More variants were introduced in:
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016
We can obtain different ResNet variants by changing the network depth, width,
and form of residual unit. This module implements the infrastructure for
building them. Concrete ResNet units and full ResNet networks are implemented in
the accompanying resnet_v1.py and resnet_v2.py modules.
Compared to https://github.com/KaimingHe/deep-residual-networks, in the current
implementation we subsample the output activations in the last residual unit of
each block, instead of subsampling the input activations in the first residual
unit of each block. The two implementations give identical results but our
implementation is more memory efficient.
"""

import collections

import tensorflow as tf
import tf_slim as slim

if tf.__version__ >= '2.0':
    tf = tf.compat.v1


class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])):
"""A named tuple describing a ResNet block.
Its parts are:
scope: The scope of the `Block`.
unit_fn: The ResNet unit function which takes as input a `Tensor` and
returns another `Tensor` with the output of the ResNet unit.
args: A list of length equal to the number of units in the `Block`. The list
contains one (depth, depth_bottleneck, stride) tuple for each unit in the
block to serve as argument to unit_fn.
"""


def subsample(inputs, factor, scope=None):
"""Subsamples the input along the spatial dimensions.
Args:
inputs: A `Tensor` of size [batch, height_in, width_in, channels].
factor: The subsampling factor.
scope: Optional variable_scope.
Returns:
output: A `Tensor` of size [batch, height_out, width_out, channels] with the
input, either intact (if factor == 1) or subsampled (if factor > 1).
"""
if factor == 1:
return inputs
else:
return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope)


def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None):
"""Strided 2-D convolution with 'SAME' padding.
When stride > 1, then we do explicit zero-padding, followed by conv2d with
'VALID' padding.
Note that
net = conv2d_same(inputs, num_outputs, 3, stride=stride)
is equivalent to
net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME')
net = subsample(net, factor=stride)
whereas
net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME')
is different when the input's height or width is even, which is why we add the
current function. For more details, see ResnetUtilsTest.testConv2DSameEven().
Args:
inputs: A 4-D tensor of size [batch, height_in, width_in, channels].
num_outputs: An integer, the number of output filters.
kernel_size: An int with the kernel_size of the filters.
stride: An integer, the output stride.
rate: An integer, rate for atrous convolution.
scope: Scope.
Returns:
output: A 4-D tensor of size [batch, height_out, width_out, channels] with
the convolution output.
"""
if stride == 1:
return slim.conv2d(
inputs,
num_outputs,
kernel_size,
stride=1,
rate=rate,
padding='SAME',
scope=scope)
else:
kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
pad_total = kernel_size_effective - 1
pad_beg = pad_total // 2
pad_end = pad_total - pad_beg
inputs = tf.pad(
inputs, [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
return slim.conv2d(
inputs,
num_outputs,
kernel_size,
stride=stride,
rate=rate,
padding='VALID',
scope=scope)
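
For the strided branch, the amount of explicit zero-padding depends only on the kernel size and atrous rate; a small plain-Python sketch of that arithmetic:

def same_padding(kernel_size, rate=1):
    # effective kernel size under atrous convolution, then split the total padding
    kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
    pad_total = kernel_size_effective - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg
    return pad_beg, pad_end

print(same_padding(3))          # (1, 1): pad one pixel per side, then a VALID conv
print(same_padding(7))          # (3, 3): used for the stride-2 conv1 in resnet_v1
print(same_padding(3, rate=2))  # (2, 2): a dilated 3x3 kernel behaves like a 5x5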


@slim.add_arg_scope
def stack_blocks_dense(net,
blocks,
output_stride=None,
outputs_collections=None):
"""Stacks ResNet `Blocks` and controls output feature density.
First, this function creates scopes for the ResNet in the form of
'block_name/unit_1', 'block_name/unit_2', etc.
Second, this function allows the user to explicitly control the ResNet
output_stride, which is the ratio of the input to output spatial resolution.
This is useful for dense prediction tasks such as semantic segmentation or
object detection.
Most ResNets consist of 4 ResNet blocks and subsample the activations by a
factor of 2 when transitioning between consecutive ResNet blocks. This results
in a nominal ResNet output_stride equal to 8. If we set the output_stride to
half the nominal network stride (e.g., output_stride=4), then we compute
responses twice.
Control of the output feature density is implemented by atrous convolution.
Args:
net: A `Tensor` of size [batch, height, width, channels].
blocks: A list of length equal to the number of ResNet `Blocks`. Each
element is a ResNet `Block` object describing the units in the `Block`.
output_stride: If `None`, then the output will be computed at the nominal
network stride. If output_stride is not `None`, it specifies the requested
ratio of input to output spatial resolution, which needs to be equal to
the product of unit strides from the start up to some level of the ResNet.
For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1,
then valid values for the output_stride are 1, 2, 6, 24 or None (which
is equivalent to output_stride=24).
outputs_collections: Collection to add the ResNet block outputs.
Returns:
net: Output tensor with stride equal to the specified output_stride.
Raises:
ValueError: If the target output_stride is not valid.
"""
# The current_stride variable keeps track of the effective stride of the
# activations. This allows us to invoke atrous convolution whenever applying
# the next residual unit would result in the activations having stride larger
# than the target output_stride.
current_stride = 1

# The atrous convolution rate parameter.
rate = 1

for block in blocks:
with tf.variable_scope(block.scope, 'block', [net]):
for i, unit in enumerate(block.args):
if output_stride is not None and current_stride > output_stride:
raise ValueError(
'The target output_stride cannot be reached.')

with tf.variable_scope(
'unit_%d' % (i + 1), values=[net]) as sc:
unit_depth, unit_depth_bottleneck, unit_stride = unit
# If we have reached the target output_stride, then we need to employ
# atrous convolution with stride=1 and multiply the atrous rate by the
# current unit's stride for use in subsequent layers.
if output_stride is not None and current_stride == output_stride:
net = block.unit_fn(
net,
depth=unit_depth,
depth_bottleneck=unit_depth_bottleneck,
stride=1,
rate=rate)
rate *= unit_stride

else:
net = block.unit_fn(
net,
depth=unit_depth,
depth_bottleneck=unit_depth_bottleneck,
stride=unit_stride,
rate=1)
current_stride *= unit_stride
net = slim.utils.collect_named_outputs(
outputs_collections, sc.name, net)

if output_stride is not None and current_stride != output_stride:
raise ValueError('The target output_stride cannot be reached.')

return net


def resnet_arg_scope(weight_decay=0.0001,
batch_norm_decay=0.997,
batch_norm_epsilon=1e-5,
batch_norm_scale=True):
"""Defines the default ResNet arg scope.
TODO(gpapan): The batch-normalization related default values above are
appropriate for use in conjunction with the reference ResNet models
released at https://github.com/KaimingHe/deep-residual-networks. When
training ResNets from scratch, they might need to be tuned.
Args:
weight_decay: The weight decay to use for regularizing the model.
batch_norm_decay: The moving average decay when estimating layer activation
statistics in batch normalization.
batch_norm_epsilon: Small constant to prevent division by zero when
normalizing activations by their variance in batch normalization.
batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
activations in the batch normalization layer.
Returns:
An `arg_scope` to use for the resnet models.
"""
batch_norm_params = {
'decay': batch_norm_decay,
'epsilon': batch_norm_epsilon,
'scale': batch_norm_scale,
'updates_collections': tf.GraphKeys.UPDATE_OPS,
}

with slim.arg_scope(
[slim.conv2d],
weights_regularizer=slim.l2_regularizer(weight_decay),
weights_initializer=slim.variance_scaling_initializer(),
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm,
normalizer_params=batch_norm_params):
with slim.arg_scope([slim.batch_norm], **batch_norm_params):
# Unlike the upstream slim reference, which defaults to padding='SAME' for
# pool1 to ease feature alignment for dense prediction tasks (see also
# https://github.com/facebook/fb.resnet.torch), this implementation sets
# padding='VALID' for max_pool2d and relies on the explicit tf.pad applied
# before pool1 in the accompanying resnet_v1() definition.
with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
return arg_sc

modelscope/pipelines/cv/ocr_utils/utils.py (+108, -0)

@@ -0,0 +1,108 @@
import cv2
import numpy as np


def rboxes_to_polygons(rboxes):
    """
    Convert rboxes to polygons
    ARGS
        `rboxes`: [n, 5]
    RETURN
        `polygons`: [n, 8]
    """

    theta = rboxes[:, 4:5]
    cxcy = rboxes[:, :2]
    half_w = rboxes[:, 2:3] / 2.
    half_h = rboxes[:, 3:4] / 2.
    v1 = np.hstack([np.cos(theta) * half_w, np.sin(theta) * half_w])
    v2 = np.hstack([-np.sin(theta) * half_h, np.cos(theta) * half_h])
    p1 = cxcy - v1 - v2
    p2 = cxcy + v1 - v2
    p3 = cxcy + v1 + v2
    p4 = cxcy - v1 + v2
    polygons = np.hstack([p1, p2, p3, p4])
    return polygons


def cal_width(box):
    pd1 = point_dist(box[0], box[1], box[2], box[3])
    pd2 = point_dist(box[4], box[5], box[6], box[7])
    return (pd1 + pd2) / 2


def point_dist(x1, y1, x2, y2):
    return np.sqrt((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1))


def draw_polygons(img, polygons):
    for p in polygons.tolist():
        p = [int(o) for o in p]
        cv2.line(img, (p[0], p[1]), (p[2], p[3]), (0, 255, 0), 1)
        cv2.line(img, (p[2], p[3]), (p[4], p[5]), (0, 255, 0), 1)
        cv2.line(img, (p[4], p[5]), (p[6], p[7]), (0, 255, 0), 1)
        cv2.line(img, (p[6], p[7]), (p[0], p[1]), (0, 255, 0), 1)
    return img


def nms_python(boxes):
    boxes = sorted(boxes, key=lambda x: -x[8])
    nms_flag = [True] * len(boxes)
    for i, a in enumerate(boxes):
        if not nms_flag[i]:
            continue
        else:
            for j, b in enumerate(boxes):
                if not j > i:
                    continue
                if not nms_flag[j]:
                    continue
                score_a = a[8]
                score_b = b[8]
                rbox_a = polygon2rbox(a[:8])
                rbox_b = polygon2rbox(b[:8])
                if point_in_rbox(rbox_a[:2], rbox_b) or point_in_rbox(
                        rbox_b[:2], rbox_a):
                    if score_a > score_b:
                        nms_flag[j] = False
    boxes_nms = []
    for i, box in enumerate(boxes):
        if nms_flag[i]:
            boxes_nms.append(box)
    return boxes_nms


def point_in_rbox(c, rbox):
    cx0, cy0 = c[0], c[1]
    cx1, cy1 = rbox[0], rbox[1]
    w, h = rbox[2], rbox[3]
    theta = rbox[4]
    dist_x = np.abs((cx1 - cx0) * np.cos(theta) + (cy1 - cy0) * np.sin(theta))
    dist_y = np.abs(-(cx1 - cx0) * np.sin(theta) + (cy1 - cy0) * np.cos(theta))
    return ((dist_x < w / 2.0) and (dist_y < h / 2.0))


def polygon2rbox(polygon):
    x1, x2, x3, x4 = polygon[0], polygon[2], polygon[4], polygon[6]
    y1, y2, y3, y4 = polygon[1], polygon[3], polygon[5], polygon[7]
    c_x = (x1 + x2 + x3 + x4) / 4
    c_y = (y1 + y2 + y3 + y4) / 4
    w1 = point_dist(x1, y1, x2, y2)
    w2 = point_dist(x3, y3, x4, y4)
    h1 = point_line_dist(c_x, c_y, x1, y1, x2, y2)
    h2 = point_line_dist(c_x, c_y, x3, y3, x4, y4)
    h = h1 + h2
    w = (w1 + w2) / 2
    theta1 = np.arctan2(y2 - y1, x2 - x1)
    theta2 = np.arctan2(y3 - y4, x3 - x4)
    theta = (theta1 + theta2) / 2.0
    return [c_x, c_y, w, h, theta]


def point_line_dist(px, py, x1, y1, x2, y2):
    eps = 1e-6
    dx = x2 - x1
    dy = y2 - y1
    div = np.sqrt(dx * dx + dy * dy) + eps
    dist = np.abs(px * dy - py * dx + x2 * y1 - y2 * x1) / div
    return dist
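
These helpers are plain NumPy, so the rbox ([cx, cy, w, h, theta]) and polygon ([x1, y1, ..., x4, y4]) conventions can be sanity-checked in isolation; a quick round-trip sketch with arbitrary values:

import numpy as np

from modelscope.pipelines.cv.ocr_utils.utils import (polygon2rbox,
                                                     rboxes_to_polygons)

rbox = np.array([[200., 120., 60., 18., 0.3]])  # [cx, cy, w, h, theta]
poly = rboxes_to_polygons(rbox)[0]              # 8 corner coordinates
recovered = polygon2rbox(poly.tolist())

# centre, size and angle should come back close to the original rbox
print(np.round(poly, 2))
print(np.round(recovered, 2))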

modelscope/pipelines/outputs.py (+7, -0)

@@ -54,6 +54,13 @@ TASK_OUTPUTS = {
    # }
    Tasks.pose_estimation: ['poses', 'boxes'],

    # ocr detection result for single sample
    # {
    #   "det_polygons": np.array with shape [num_text, 8], each box is
    #   [x1, y1, x2, y2, x3, y3, x4, y4]
    # }
    Tasks.ocr_detection: ['det_polygons'],

    # ============ nlp tasks ===================

    # text classification result for single sample
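
Since the task output is just a [num_text, 8] array, a result can be visualised with the draw_polygons helper added in this commit; a short sketch (the image path is the new test asset, the output filename is arbitrary):

import cv2
import numpy as np

from modelscope.pipelines import pipeline
from modelscope.pipelines.cv.ocr_utils.utils import draw_polygons
from modelscope.utils.constant import Tasks

ocr_detection = pipeline(Tasks.ocr_detection)
result = ocr_detection('data/test/images/ocr_detection.jpg')

img = cv2.imread('data/test/images/ocr_detection.jpg')
img = draw_polygons(img, np.asarray(result['det_polygons']))
cv2.imwrite('ocr_detection_vis.jpg', img)  # green quadrilaterals over detected lines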


modelscope/utils/constant.py (+1, -0)

@@ -28,6 +28,7 @@ class Tasks(object):
    image_editing = 'image-editing'
    image_generation = 'image-generation'
    image_matting = 'image-matting'
    ocr_detection = 'ocr-detection'

    # nlp tasks
    word_segmentation = 'word-segmentation'


requirements/cv.txt (+1, -0)

@@ -1 +1,2 @@
easydict
tf_slim

tests/pipelines/test_ocr_detection.py (+37, -0)

@@ -0,0 +1,37 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import shutil
import sys
import tempfile
import unittest
from typing import Any, Dict, List, Tuple, Union

import cv2
import numpy as np
import PIL

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class OCRDetectionTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_resnet18_ocr-detection-line-level_damo'
        self.test_image = 'data/test/images/ocr_detection.jpg'

    def pipeline_inference(self, pipeline: Pipeline, input_location: str):
        result = pipeline(input_location)
        print('ocr detection results: ')
        print(result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        ocr_detection = pipeline(Tasks.ocr_detection)
        self.pipeline_inference(ocr_detection, self.test_image)


if __name__ == '__main__':
    unittest.main()
