|
|
@@ -8,7 +8,6 @@ import cv2 |
|
|
|
import numpy as np |
|
|
|
import PIL |
|
|
|
import tensorflow as tf |
|
|
|
import tf_slim as slim |
|
|
|
|
|
|
|
from modelscope.metainfo import Pipelines |
|
|
|
from modelscope.pipelines.base import Input |
|
|
@@ -19,6 +18,11 @@ from ..base import Pipeline |
|
|
|
from ..builder import PIPELINES |
|
|
|
from .ocr_utils import model_resnet_mutex_v4_linewithchar, ops, utils |
|
|
|
|
|
|
|
if tf.__version__ >= '2.0': |
|
|
|
import tf_slim as slim |
|
|
|
else: |
|
|
|
from tensorflow.contrib import slim |
|
|
|
|
|
|
|
if tf.__version__ >= '2.0': |
|
|
|
tf = tf.compat.v1 |
|
|
|
tf.compat.v1.disable_eager_execution() |
|
|
@@ -44,6 +48,7 @@ class OCRDetectionPipeline(Pipeline): |
|
|
|
|
|
|
|
def __init__(self, model: str): |
|
|
|
super().__init__(model=model) |
|
|
|
tf.reset_default_graph() |
|
|
|
model_path = osp.join( |
|
|
|
osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER), |
|
|
|
'checkpoint-80000') |
|
|
@@ -51,51 +56,56 @@ class OCRDetectionPipeline(Pipeline): |
|
|
|
config = tf.ConfigProto(allow_soft_placement=True) |
|
|
|
config.gpu_options.allow_growth = True |
|
|
|
self._session = tf.Session(config=config) |
|
|
|
global_step = tf.get_variable( |
|
|
|
'global_step', [], |
|
|
|
initializer=tf.constant_initializer(0), |
|
|
|
dtype=tf.int64, |
|
|
|
trainable=False) |
|
|
|
variable_averages = tf.train.ExponentialMovingAverage( |
|
|
|
0.997, global_step) |
|
|
|
self.input_images = tf.placeholder( |
|
|
|
tf.float32, shape=[1, 1024, 1024, 3], name='input_images') |
|
|
|
self.output = {} |
|
|
|
|
|
|
|
# detector |
|
|
|
detector = model_resnet_mutex_v4_linewithchar.SegLinkDetector() |
|
|
|
all_maps = detector.build_model(self.input_images, is_training=False) |
|
|
|
|
|
|
|
# decode local predictions |
|
|
|
all_nodes, all_links, all_reg = [], [], [] |
|
|
|
for i, maps in enumerate(all_maps): |
|
|
|
cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2] |
|
|
|
reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE) |
|
|
|
|
|
|
|
cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2])) |
|
|
|
|
|
|
|
lnk_prob_pos = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 4])[:, :2]) |
|
|
|
lnk_prob_mut = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 4])[:, 2:]) |
|
|
|
lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1) |
|
|
|
|
|
|
|
all_nodes.append(cls_prob) |
|
|
|
all_links.append(lnk_prob) |
|
|
|
all_reg.append(reg_maps) |
|
|
|
|
|
|
|
# decode segments and links |
|
|
|
image_size = tf.shape(self.input_images)[1:3] |
|
|
|
segments, group_indices, segment_counts, _ = ops.decode_segments_links_python( |
|
|
|
image_size, |
|
|
|
all_nodes, |
|
|
|
all_links, |
|
|
|
all_reg, |
|
|
|
anchor_sizes=list(detector.anchor_sizes)) |
|
|
|
|
|
|
|
# combine segments |
|
|
|
combined_rboxes, combined_counts = ops.combine_segments_python( |
|
|
|
segments, group_indices, segment_counts) |
|
|
|
self.output['combined_rboxes'] = combined_rboxes |
|
|
|
self.output['combined_counts'] = combined_counts |
|
|
|
with tf.variable_scope('', reuse=tf.AUTO_REUSE): |
|
|
|
global_step = tf.get_variable( |
|
|
|
'global_step', [], |
|
|
|
initializer=tf.constant_initializer(0), |
|
|
|
dtype=tf.int64, |
|
|
|
trainable=False) |
|
|
|
variable_averages = tf.train.ExponentialMovingAverage( |
|
|
|
0.997, global_step) |
|
|
|
|
|
|
|
# detector |
|
|
|
detector = model_resnet_mutex_v4_linewithchar.SegLinkDetector() |
|
|
|
all_maps = detector.build_model( |
|
|
|
self.input_images, is_training=False) |
|
|
|
|
|
|
|
# decode local predictions |
|
|
|
all_nodes, all_links, all_reg = [], [], [] |
|
|
|
for i, maps in enumerate(all_maps): |
|
|
|
cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2] |
|
|
|
reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE) |
|
|
|
|
|
|
|
cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2])) |
|
|
|
|
|
|
|
lnk_prob_pos = tf.nn.softmax( |
|
|
|
tf.reshape(lnk_maps, [-1, 4])[:, :2]) |
|
|
|
lnk_prob_mut = tf.nn.softmax( |
|
|
|
tf.reshape(lnk_maps, [-1, 4])[:, 2:]) |
|
|
|
lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1) |
|
|
|
|
|
|
|
all_nodes.append(cls_prob) |
|
|
|
all_links.append(lnk_prob) |
|
|
|
all_reg.append(reg_maps) |
|
|
|
|
|
|
|
# decode segments and links |
|
|
|
image_size = tf.shape(self.input_images)[1:3] |
|
|
|
segments, group_indices, segment_counts, _ = ops.decode_segments_links_python( |
|
|
|
image_size, |
|
|
|
all_nodes, |
|
|
|
all_links, |
|
|
|
all_reg, |
|
|
|
anchor_sizes=list(detector.anchor_sizes)) |
|
|
|
|
|
|
|
# combine segments |
|
|
|
combined_rboxes, combined_counts = ops.combine_segments_python( |
|
|
|
segments, group_indices, segment_counts) |
|
|
|
self.output['combined_rboxes'] = combined_rboxes |
|
|
|
self.output['combined_counts'] = combined_counts |
|
|
|
|
|
|
|
with self._session.as_default() as sess: |
|
|
|
logger.info(f'loading model from {model_path}') |
|
|
|