Browse Source

[to #42322933]Merge branch 'master' into ocr/ocr_detection

修复master分支ocr_detection 单元测试bug
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9112290

    * create ocr_detection task

* fix code check error

* fix code check error

* fix code check issue

* fix code check issue

* replace c++ nms with python version

* fix code check issue

* fix code check issue

* rename maas_lib

* merge master to ocr/ocr_detection

* add model_hub sup for ocr_detection

* fix bug

* replace c++ decoder with python version

* fix bug

* Merge branch 'master' into ocr/ocr_detection

* merge master

* fix code check

* update

* add requirements for ocr_detection

* fix model_hub fetch bug

* remove debug code

* Merge branch 'master' into ocr/ocr_detection

* add local test image for ocr_detection

* update requirements for model_hub

* Merge branch 'master' into ocr/ocr_detection

* fix bug for full case test

* remove ema for ocr_detection

* Merge branch 'master' into ocr/ocr_detection

* apply ocr_detection test case

* Merge branch 'master' into ocr/ocr_detection

* update slim dependency for ocr_detection

* add more test case for ocr_detection

* release tf graph before create

* recover ema for ocr_detection model

* fix code

* Merge branch 'master' into ocr/ocr_detection

* fix code
master
xixing.tj huangjun.hj 3 years ago
parent
commit
d4692b5ada
5 changed files with 72 additions and 45 deletions
  1. +52
    -42
      modelscope/pipelines/cv/ocr_detection_pipeline.py
  2. +5
    -1
      modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py
  3. +5
    -1
      modelscope/pipelines/cv/ocr_utils/resnet18_v1.py
  4. +5
    -1
      modelscope/pipelines/cv/ocr_utils/resnet_utils.py
  5. +5
    -0
      tests/pipelines/test_ocr_detection.py

+ 52
- 42
modelscope/pipelines/cv/ocr_detection_pipeline.py View File

@@ -8,7 +8,6 @@ import cv2
import numpy as np
import PIL
import tensorflow as tf
import tf_slim as slim

from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Input
@@ -19,6 +18,11 @@ from ..base import Pipeline
from ..builder import PIPELINES
from .ocr_utils import model_resnet_mutex_v4_linewithchar, ops, utils

if tf.__version__ >= '2.0':
import tf_slim as slim
else:
from tensorflow.contrib import slim

if tf.__version__ >= '2.0':
tf = tf.compat.v1
tf.compat.v1.disable_eager_execution()
@@ -44,6 +48,7 @@ class OCRDetectionPipeline(Pipeline):

def __init__(self, model: str):
super().__init__(model=model)
tf.reset_default_graph()
model_path = osp.join(
osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER),
'checkpoint-80000')
@@ -51,51 +56,56 @@ class OCRDetectionPipeline(Pipeline):
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
self._session = tf.Session(config=config)
global_step = tf.get_variable(
'global_step', [],
initializer=tf.constant_initializer(0),
dtype=tf.int64,
trainable=False)
variable_averages = tf.train.ExponentialMovingAverage(
0.997, global_step)
self.input_images = tf.placeholder(
tf.float32, shape=[1, 1024, 1024, 3], name='input_images')
self.output = {}

# detector
detector = model_resnet_mutex_v4_linewithchar.SegLinkDetector()
all_maps = detector.build_model(self.input_images, is_training=False)

# decode local predictions
all_nodes, all_links, all_reg = [], [], []
for i, maps in enumerate(all_maps):
cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2]
reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE)

cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))

lnk_prob_pos = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 4])[:, :2])
lnk_prob_mut = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 4])[:, 2:])
lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1)

all_nodes.append(cls_prob)
all_links.append(lnk_prob)
all_reg.append(reg_maps)

# decode segments and links
image_size = tf.shape(self.input_images)[1:3]
segments, group_indices, segment_counts, _ = ops.decode_segments_links_python(
image_size,
all_nodes,
all_links,
all_reg,
anchor_sizes=list(detector.anchor_sizes))

# combine segments
combined_rboxes, combined_counts = ops.combine_segments_python(
segments, group_indices, segment_counts)
self.output['combined_rboxes'] = combined_rboxes
self.output['combined_counts'] = combined_counts
with tf.variable_scope('', reuse=tf.AUTO_REUSE):
global_step = tf.get_variable(
'global_step', [],
initializer=tf.constant_initializer(0),
dtype=tf.int64,
trainable=False)
variable_averages = tf.train.ExponentialMovingAverage(
0.997, global_step)

# detector
detector = model_resnet_mutex_v4_linewithchar.SegLinkDetector()
all_maps = detector.build_model(
self.input_images, is_training=False)

# decode local predictions
all_nodes, all_links, all_reg = [], [], []
for i, maps in enumerate(all_maps):
cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2]
reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE)

cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))

lnk_prob_pos = tf.nn.softmax(
tf.reshape(lnk_maps, [-1, 4])[:, :2])
lnk_prob_mut = tf.nn.softmax(
tf.reshape(lnk_maps, [-1, 4])[:, 2:])
lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1)

all_nodes.append(cls_prob)
all_links.append(lnk_prob)
all_reg.append(reg_maps)

# decode segments and links
image_size = tf.shape(self.input_images)[1:3]
segments, group_indices, segment_counts, _ = ops.decode_segments_links_python(
image_size,
all_nodes,
all_links,
all_reg,
anchor_sizes=list(detector.anchor_sizes))

# combine segments
combined_rboxes, combined_counts = ops.combine_segments_python(
segments, group_indices, segment_counts)
self.output['combined_rboxes'] = combined_rboxes
self.output['combined_counts'] = combined_counts

with self._session.as_default() as sess:
logger.info(f'loading model from {model_path}')


+ 5
- 1
modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py View File

@@ -1,8 +1,12 @@
import tensorflow as tf
import tf_slim as slim

from . import ops, resnet18_v1, resnet_utils

if tf.__version__ >= '2.0':
import tf_slim as slim
else:
from tensorflow.contrib import slim

if tf.__version__ >= '2.0':
tf = tf.compat.v1



+ 5
- 1
modelscope/pipelines/cv/ocr_utils/resnet18_v1.py View File

@@ -30,10 +30,14 @@ ResNet-101 for semantic segmentation into 21 classes:
output_stride=16)
"""
import tensorflow as tf
import tf_slim as slim

from . import resnet_utils

if tf.__version__ >= '2.0':
import tf_slim as slim
else:
from tensorflow.contrib import slim

if tf.__version__ >= '2.0':
tf = tf.compat.v1



+ 5
- 1
modelscope/pipelines/cv/ocr_utils/resnet_utils.py View File

@@ -19,7 +19,11 @@ implementation is more memory efficient.
import collections

import tensorflow as tf
import tf_slim as slim

if tf.__version__ >= '2.0':
import tf_slim as slim
else:
from tensorflow.contrib import slim

if tf.__version__ >= '2.0':
tf = tf.compat.v1


+ 5
- 0
tests/pipelines/test_ocr_detection.py View File

@@ -27,6 +27,11 @@ class OCRDetectionTest(unittest.TestCase):
print('ocr detection results: ')
print(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
ocr_detection = pipeline(Tasks.ocr_detection, model=self.model_id)
self.pipeline_inference(ocr_detection, self.test_image)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_modelhub_default_model(self):
ocr_detection = pipeline(Tasks.ocr_detection)


Loading…
Cancel
Save