zhangzhicheng.zzc yingda.chen 2 years ago
parent commit 7fc49e5fa0
11 changed files with 1142 additions and 0 deletions
  1. data/test/images/table_recognition.jpg                  +3 −0
  2. modelscope/metainfo.py                                   +1 −0
  3. modelscope/outputs/outputs.py                            +1 −0
  4. modelscope/pipelines/builder.py                          +3 −0
  5. modelscope/pipelines/cv/__init__.py                      +2 −0
  6. modelscope/pipelines/cv/ocr_utils/model_dla34.py         +655 −0
  7. modelscope/pipelines/cv/ocr_utils/table_process.py       +315 −0
  8. modelscope/pipelines/cv/table_recognition_pipeline.py    +119 −0
  9. modelscope/utils/constant.py                             +1 −0
  10. tests/pipelines/test_table_recognition.py               +41 −0
  11. tests/run_config.yaml                                   +1 −0

data/test/images/table_recognition.jpg  (+3 −0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f4b7e23f02a35136707ac7862e0a8468797f239e89497351847cfacb2a9c24f6
size 202112

modelscope/metainfo.py  (+1 −0)

@@ -151,6 +151,7 @@ class Pipelines(object):
    image_denoise = 'nafnet-image-denoise'
    person_image_cartoon = 'unet-person-image-cartoon'
    ocr_detection = 'resnet18-ocr-detection'
    table_recognition = 'dla34-table-recognition'
    action_recognition = 'TAdaConv_action-recognition'
    animal_recognition = 'resnet101-animal-recognition'
    general_recognition = 'resnet101-general-recognition'


modelscope/outputs/outputs.py  (+1 −0)

@@ -59,6 +59,7 @@ TASK_OUTPUTS = {
    #   [x1, y1, x2, y2, x3, y3, x4, y4]
    # }
    Tasks.ocr_detection: [OutputKeys.POLYGONS],
    Tasks.table_recognition: [OutputKeys.POLYGONS],

    # ocr recognition result for single sample
    # {


modelscope/pipelines/builder.py  (+3 −0)

@@ -82,6 +82,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                                 'damo/cv_unet_person-image-cartoon_compound-models'),
    Tasks.ocr_detection: (Pipelines.ocr_detection,
                          'damo/cv_resnet18_ocr-detection-line-level_damo'),
    Tasks.table_recognition:
    (Pipelines.table_recognition,
     'damo/cv_dla34_table-structure-recognition_cycle-centernet'),
    Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
    Tasks.feature_extraction: (Pipelines.feature_extraction,
                               'damo/pert_feature-extraction_base-test'),
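
A quick usage sketch (mirroring the new test file below; the model id is the default registered above), so the pipeline can be built from the task name alone:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# resolves to 'damo/cv_dla34_table-structure-recognition_cycle-centernet'
table_recognition = pipeline(Tasks.table_recognition)
result = table_recognition('data/test/images/table_recognition.jpg')
print(result)  # {'polygons': ndarray of shape (N, 8)}, one quadrilateral per cell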


modelscope/pipelines/cv/__init__.py  (+2 −0)

@@ -41,6 +41,7 @@ if TYPE_CHECKING:
    from .live_category_pipeline import LiveCategoryPipeline
    from .ocr_detection_pipeline import OCRDetectionPipeline
    from .ocr_recognition_pipeline import OCRRecognitionPipeline
    from .table_recognition_pipeline import TableRecognitionPipeline
    from .skin_retouching_pipeline import SkinRetouchingPipeline
    from .tinynas_classification_pipeline import TinynasClassificationPipeline
    from .video_category_pipeline import VideoCategoryPipeline
@@ -108,6 +109,7 @@ else:
        'image_inpainting_pipeline': ['ImageInpaintingPipeline'],
        'ocr_detection_pipeline': ['OCRDetectionPipeline'],
        'ocr_recognition_pipeline': ['OCRRecognitionPipeline'],
        'table_recognition_pipeline': ['TableRecognitionPipeline'],
        'skin_retouching_pipeline': ['SkinRetouchingPipeline'],
        'tinynas_classification_pipeline': ['TinynasClassificationPipeline'],
        'video_category_pipeline': ['VideoCategoryPipeline'],


modelscope/pipelines/cv/ocr_utils/model_dla34.py  (+655 −0)

@@ -0,0 +1,655 @@
# ------------------------------------------------------------------------------
# The implementation is adopted from CenterNet,
# made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git
# ------------------------------------------------------------------------------

import math
from os.path import join

import numpy as np
import torch
from torch import nn

BatchNorm = nn.BatchNorm2d


class BasicBlock(nn.Module):

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            bias=False,
            dilation=dilation)
        self.bn1 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=1,
            padding=dilation,
            bias=False,
            dilation=dilation)
        self.bn2 = BatchNorm(planes)
        self.stride = stride

    def forward(self, x, residual=None):
        if residual is None:
            residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 2

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(Bottleneck, self).__init__()
        expansion = Bottleneck.expansion
        bottle_planes = planes // expansion
        self.conv1 = nn.Conv2d(
            inplanes, bottle_planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(bottle_planes)
        self.conv2 = nn.Conv2d(
            bottle_planes,
            bottle_planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            bias=False,
            dilation=dilation)
        self.bn2 = BatchNorm(bottle_planes)
        self.conv3 = nn.Conv2d(
            bottle_planes, planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

    def forward(self, x, residual=None):
        if residual is None:
            residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += residual
        out = self.relu(out)

        return out


class BottleneckX(nn.Module):
    expansion = 2
    cardinality = 32

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(BottleneckX, self).__init__()
        cardinality = BottleneckX.cardinality
        bottle_planes = planes * cardinality // 32
        self.conv1 = nn.Conv2d(
            inplanes, bottle_planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(bottle_planes)
        self.conv2 = nn.Conv2d(
            bottle_planes,
            bottle_planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            bias=False,
            dilation=dilation,
            groups=cardinality)
        self.bn2 = BatchNorm(bottle_planes)
        self.conv3 = nn.Conv2d(
            bottle_planes, planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

    def forward(self, x, residual=None):
        if residual is None:
            residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += residual
        out = self.relu(out)

        return out


class Root(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, residual):
        super(Root, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            1,
            stride=1,
            bias=False,
            padding=(kernel_size - 1) // 2)
        self.bn = BatchNorm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.residual = residual

    def forward(self, *x):
        children = x
        x = self.conv(torch.cat(x, 1))
        x = self.bn(x)
        if self.residual:
            x += children[0]
        x = self.relu(x)

        return x


class Tree(nn.Module):

    def __init__(self,
                 levels,
                 block,
                 in_channels,
                 out_channels,
                 stride=1,
                 level_root=False,
                 root_dim=0,
                 root_kernel_size=1,
                 dilation=1,
                 root_residual=False):
        super(Tree, self).__init__()
        if root_dim == 0:
            root_dim = 2 * out_channels
        if level_root:
            root_dim += in_channels
        if levels == 1:
            self.tree1 = block(
                in_channels, out_channels, stride, dilation=dilation)
            self.tree2 = block(
                out_channels, out_channels, 1, dilation=dilation)
        else:
            self.tree1 = Tree(
                levels - 1,
                block,
                in_channels,
                out_channels,
                stride,
                root_dim=0,
                root_kernel_size=root_kernel_size,
                dilation=dilation,
                root_residual=root_residual)
            self.tree2 = Tree(
                levels - 1,
                block,
                out_channels,
                out_channels,
                root_dim=root_dim + out_channels,
                root_kernel_size=root_kernel_size,
                dilation=dilation,
                root_residual=root_residual)
        if levels == 1:
            self.root = Root(root_dim, out_channels, root_kernel_size,
                             root_residual)
        self.level_root = level_root
        self.root_dim = root_dim
        self.downsample = None
        self.project = None
        self.levels = levels
        if stride > 1:
            self.downsample = nn.MaxPool2d(stride, stride=stride)
        if in_channels != out_channels:
            self.project = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    bias=False), BatchNorm(out_channels))

    def forward(self, x, residual=None, children=None):
        children = [] if children is None else children
        bottom = self.downsample(x) if self.downsample else x
        residual = self.project(bottom) if self.project else bottom
        if self.level_root:
            children.append(bottom)
        x1 = self.tree1(x, residual)
        if self.levels == 1:
            x2 = self.tree2(x1)
            x = self.root(x2, x1, *children)
        else:
            children.append(x1)
            x = self.tree2(x1, children=children)
        return x


class DLA(nn.Module):

    def __init__(self,
                 levels,
                 channels,
                 num_classes=1000,
                 block=BasicBlock,
                 residual_root=False,
                 return_levels=False,
                 pool_size=7,
                 linear_root=False):
        super(DLA, self).__init__()
        self.channels = channels
        self.return_levels = return_levels
        self.num_classes = num_classes
        self.base_layer = nn.Sequential(
            nn.Conv2d(
                3, channels[0], kernel_size=7, stride=1, padding=3,
                bias=False), BatchNorm(channels[0]), nn.ReLU(inplace=True))
        self.level0 = self._make_conv_level(channels[0], channels[0],
                                            levels[0])
        self.level1 = self._make_conv_level(
            channels[0], channels[1], levels[1], stride=2)
        self.level2 = Tree(
            levels[2],
            block,
            channels[1],
            channels[2],
            2,
            level_root=False,
            root_residual=residual_root)
        self.level3 = Tree(
            levels[3],
            block,
            channels[2],
            channels[3],
            2,
            level_root=True,
            root_residual=residual_root)
        self.level4 = Tree(
            levels[4],
            block,
            channels[3],
            channels[4],
            2,
            level_root=True,
            root_residual=residual_root)
        self.level5 = Tree(
            levels[5],
            block,
            channels[4],
            channels[5],
            2,
            level_root=True,
            root_residual=residual_root)

        self.avgpool = nn.AvgPool2d(pool_size)
        self.fc = nn.Conv2d(
            channels[-1],
            num_classes,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_level(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes:
            downsample = nn.Sequential(
                nn.MaxPool2d(stride, stride=stride),
                nn.Conv2d(
                    inplanes, planes, kernel_size=1, stride=1, bias=False),
                BatchNorm(planes),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample=downsample))
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))

        return nn.Sequential(*layers)

    def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
        modules = []
        for i in range(convs):
            modules.extend([
                nn.Conv2d(
                    inplanes,
                    planes,
                    kernel_size=3,
                    stride=stride if i == 0 else 1,
                    padding=dilation,
                    bias=False,
                    dilation=dilation),
                BatchNorm(planes),
                nn.ReLU(inplace=True)
            ])
            inplanes = planes
        return nn.Sequential(*modules)

    def forward(self, x):
        y = []
        x = self.base_layer(x)
        for i in range(6):
            x = getattr(self, 'level{}'.format(i))(x)
            y.append(x)
        if self.return_levels:
            return y
        else:
            x = self.avgpool(x)
            x = self.fc(x)
            x = x.view(x.size(0), -1)

            return x


def dla34(pretrained, **kwargs):  # DLA-34
    model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512],
                block=BasicBlock,
                **kwargs)
    return model


def dla46_c(pretrained=None, **kwargs):  # DLA-46-C
    Bottleneck.expansion = 2
    model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256],
                block=Bottleneck,
                **kwargs)
    return model


def dla46x_c(pretrained=None, **kwargs):  # DLA-X-46-C
    BottleneckX.expansion = 2
    model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256],
                block=BottleneckX,
                **kwargs)
    return model


def dla60x_c(pretrained, **kwargs):  # DLA-X-60-C
    BottleneckX.expansion = 2
    model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 64, 64, 128, 256],
                block=BottleneckX,
                **kwargs)
    return model


def dla60(pretrained=None, **kwargs):  # DLA-60
    Bottleneck.expansion = 2
    model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024],
                block=Bottleneck,
                **kwargs)
    return model


def dla60x(pretrained=None, **kwargs):  # DLA-X-60
    BottleneckX.expansion = 2
    model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024],
                block=BottleneckX,
                **kwargs)
    return model


def dla102(pretrained=None, **kwargs):  # DLA-102
    Bottleneck.expansion = 2
    model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
                block=Bottleneck,
                residual_root=True,
                **kwargs)
    return model


def dla102x(pretrained=None, **kwargs):  # DLA-X-102
    BottleneckX.expansion = 2
    model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
                block=BottleneckX,
                residual_root=True,
                **kwargs)
    return model


def dla102x2(pretrained=None, **kwargs):  # DLA-X-102 64
    BottleneckX.cardinality = 64
    model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
                block=BottleneckX,
                residual_root=True,
                **kwargs)
    return model


def dla169(pretrained=None, **kwargs):  # DLA-169
    Bottleneck.expansion = 2
    model = DLA([1, 1, 2, 3, 5, 1], [16, 32, 128, 256, 512, 1024],
                block=Bottleneck,
                residual_root=True,
                **kwargs)
    return model


def set_bn(bn):
    # Swap the BatchNorm implementation used by the blocks in this module.
    # (The upstream code also rebinds it on a separate `dla` module, which
    # does not exist here; that line would raise NameError and is dropped.)
    global BatchNorm
    BatchNorm = bn


class Identity(nn.Module):

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


def fill_up_weights(up):
    # Initialize a transposed conv with a fixed bilinear-upsampling kernel.
    w = up.weight.data
    f = math.ceil(w.size(2) / 2)
    c = (2 * f - 1 - f % 2) / (2. * f)
    for i in range(w.size(2)):
        for j in range(w.size(3)):
            w[0, 0, i, j] = \
                (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
    for c in range(1, w.size(0)):
        w[c, 0, :, :] = w[0, 0, :, :]


class IDAUp(nn.Module):

    def __init__(self, node_kernel, out_dim, channels, up_factors):
        super(IDAUp, self).__init__()
        self.channels = channels
        self.out_dim = out_dim
        for i, c in enumerate(channels):
            if c == out_dim:
                proj = Identity()
            else:
                proj = nn.Sequential(
                    nn.Conv2d(c, out_dim, kernel_size=1, stride=1, bias=False),
                    BatchNorm(out_dim), nn.ReLU(inplace=True))
            f = int(up_factors[i])
            if f == 1:
                up = Identity()
            else:
                up = nn.ConvTranspose2d(
                    out_dim,
                    out_dim,
                    f * 2,
                    stride=f,
                    padding=f // 2,
                    output_padding=0,
                    groups=out_dim,
                    bias=False)
                fill_up_weights(up)
            setattr(self, 'proj_' + str(i), proj)
            setattr(self, 'up_' + str(i), up)

        for i in range(1, len(channels)):
            node = nn.Sequential(
                nn.Conv2d(
                    out_dim * 2,
                    out_dim,
                    kernel_size=node_kernel,
                    stride=1,
                    padding=node_kernel // 2,
                    bias=False), BatchNorm(out_dim), nn.ReLU(inplace=True))
            setattr(self, 'node_' + str(i), node)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, layers):
        assert len(self.channels) == len(layers), \
            '{} vs {} layers'.format(len(self.channels), len(layers))
        layers = list(layers)
        for i, l in enumerate(layers):
            upsample = getattr(self, 'up_' + str(i))
            project = getattr(self, 'proj_' + str(i))
            layers[i] = upsample(project(l))
        x = layers[0]
        y = []
        for i in range(1, len(layers)):
            node = getattr(self, 'node_' + str(i))
            x = node(torch.cat([x, layers[i]], 1))
            y.append(x)
        return x, y


class DLAUp(nn.Module):

    def __init__(self, channels, scales=(1, 2, 4, 8, 16), in_channels=None):
        super(DLAUp, self).__init__()
        if in_channels is None:
            in_channels = channels
        self.channels = channels
        channels = list(channels)
        scales = np.array(scales, dtype=int)
        for i in range(len(channels) - 1):
            j = -i - 2
            setattr(
                self, 'ida_{}'.format(i),
                IDAUp(3, channels[j], in_channels[j:],
                      scales[j:] // scales[j]))
            scales[j + 1:] = scales[j]
            in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]]

    def forward(self, layers):
        layers = list(layers)
        assert len(layers) > 1
        for i in range(len(layers) - 1):
            ida = getattr(self, 'ida_{}'.format(i))
            x, y = ida(layers[-i - 2:])
            layers[-i - 1:] = y
        return x


def fill_fc_weights(layers):
    for m in layers.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, std=0.001)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


class DLASeg(nn.Module):

    def __init__(self,
                 base_name='dla34',
                 pretrained=False,
                 down_ratio=4,
                 head_conv=256):
        super(DLASeg, self).__init__()
        assert down_ratio in [2, 4, 8, 16]
        # Output heads: 'hm' = 2-channel heatmap (cell centers, vertices),
        # 'v2c' = vertex-to-center offsets, 'c2v' = center-to-vertex offsets,
        # 'reg' = sub-pixel refinement of heatmap peaks.
        self.heads = {'hm': 2, 'v2c': 8, 'c2v': 8, 'reg': 2}
        self.first_level = int(np.log2(down_ratio))
        self.base = globals()[base_name](
            pretrained=pretrained, return_levels=True)
        channels = self.base.channels
        scales = [2**i for i in range(len(channels[self.first_level:]))]
        self.dla_up = DLAUp(channels[self.first_level:], scales=scales)

        for head in self.heads:
            classes = self.heads[head]
            if head_conv > 0:
                fc = nn.Sequential(
                    nn.Conv2d(
                        channels[self.first_level],
                        head_conv,
                        kernel_size=3,
                        padding=1,
                        bias=True), nn.ReLU(inplace=True),
                    nn.Conv2d(
                        head_conv,
                        classes,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        bias=True))
                if 'hm' in head:
                    # Bias init so the initial heatmap sigmoid is ~0.1.
                    fc[-1].bias.data.fill_(-2.19)
                else:
                    fill_fc_weights(fc)
            else:
                fc = nn.Conv2d(
                    channels[self.first_level],
                    classes,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=True)
                if 'hm' in head:
                    fc.bias.data.fill_(-2.19)
                else:
                    fill_fc_weights(fc)
            self.__setattr__(head, fc)

    def forward(self, x):
        x = self.base(x)
        x = self.dla_up(x[self.first_level:])
        ret = {}
        for head in self.heads:
            ret[head] = self.__getattr__(head)(x)
        return [ret]


def TableRecModel():
    model = DLASeg()
    return model
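
A minimal smoke test for this module (not part of the commit; the shapes follow from the default down_ratio=4):

import torch

model = TableRecModel().eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 512, 512))[0]
for name, tensor in out.items():
    print(name, tuple(tensor.shape))
# hm (1, 2, 128, 128), v2c (1, 8, 128, 128), c2v (1, 8, 128, 128), reg (1, 2, 128, 128)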

modelscope/pipelines/cv/ocr_utils/table_process.py  (+315 −0)

@@ -0,0 +1,315 @@
# ------------------------------------------------------------------------------
# The implementation is adopted from CenterNet,
# made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git
# ------------------------------------------------------------------------------

import copy
import math
import random

import cv2
import numpy as np
import torch
import torch.nn as nn


def transform_preds(coords, center, scale, output_size, rot=0):
    target_coords = np.zeros(coords.shape)
    trans = get_affine_transform(center, scale, rot, output_size, inv=1)
    for p in range(coords.shape[0]):
        target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
    return target_coords


def get_affine_transform(center,
                         scale,
                         rot,
                         output_size,
                         shift=np.array([0, 0], dtype=np.float32),
                         inv=0):
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        scale = np.array([scale, scale], dtype=np.float32)

    scale_tmp = scale
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)

    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir

    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans


def affine_transform(pt, t):
    new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32).T
    new_pt = np.dot(t, new_pt)
    return new_pt[:2]


def get_dir(src_point, rot_rad):
    sn, cs = np.sin(rot_rad), np.cos(rot_rad)

    src_result = [0, 0]
    src_result[0] = src_point[0] * cs - src_point[1] * sn
    src_result[1] = src_point[0] * sn + src_point[1] * cs

    return src_result


def get_3rd_point(a, b):
    direct = a - b
    return b + np.array([-direct[1], direct[0]], dtype=np.float32)
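
A small worked example (values assumed) of the two helpers above, used the same way the pipeline does: the forward transform warps the image into the network input, the inverse maps predictions back.

import numpy as np

c = np.array([320., 240.], dtype=np.float32)            # source image center
s = 640.0                                               # max(height, width)
trans = get_affine_transform(c, s, 0, [1024, 1024])     # src -> 1024x1024 input
pt = affine_transform(np.array([320., 240.]), trans)    # center -> ~[512, 512]
inv = get_affine_transform(c, s, 0, [1024, 1024], inv=1)
print(affine_transform(pt, inv))                        # back to ~[320, 240]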


def _sigmoid(x):
    y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4)
    return y


def _gather_feat(feat, ind, mask=None):
    dim = feat.size(2)
    ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    feat = feat.gather(1, ind)
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat


def _tranpose_and_gather_feat(feat, ind):
    feat = feat.permute(0, 2, 3, 1).contiguous()
    feat = feat.view(feat.size(0), -1, feat.size(3))
    feat = _gather_feat(feat, ind)
    return feat


def _nms(heat, kernel=3):
    pad = (kernel - 1) // 2

    hmax = nn.functional.max_pool2d(
        heat, (kernel, kernel), stride=1, padding=pad)
    keep = (hmax == heat).float()
    return heat * keep, keep


def _topk(scores, K=40):
    batch, cat, height, width = scores.size()

    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)

    topk_inds = topk_inds % (height * width)
    topk_ys = (topk_inds / width).int().float()
    topk_xs = (topk_inds % width).int().float()

    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
    topk_clses = (topk_ind / K).int()
    topk_inds = _gather_feat(topk_inds.view(batch, -1, 1),
                             topk_ind).view(batch, K)
    topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)

    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs


def bbox_decode(heat, wh, reg=None, K=100):
    # Decode the top-K cell boxes from the center heatmap: each box is the
    # refined center plus four center-to-vertex offsets (8 values in 'wh').
    batch, cat, height, width = heat.size()

    heat, keep = _nms(heat)

    scores, inds, clses, ys, xs = _topk(heat, K=K)
    if reg is not None:
        reg = _tranpose_and_gather_feat(reg, inds)
        reg = reg.view(batch, K, 2)
        xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
        ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
    else:
        xs = xs.view(batch, K, 1) + 0.5
        ys = ys.view(batch, K, 1) + 0.5
    wh = _tranpose_and_gather_feat(wh, inds)
    wh = wh.view(batch, K, 8)
    clses = clses.view(batch, K, 1).float()
    scores = scores.view(batch, K, 1)

    bboxes = torch.cat(
        [
            xs - wh[..., 0:1],
            ys - wh[..., 1:2],
            xs - wh[..., 2:3],
            ys - wh[..., 3:4],
            xs - wh[..., 4:5],
            ys - wh[..., 5:6],
            xs - wh[..., 6:7],
            ys - wh[..., 7:8],
        ],
        dim=2,
    )
    detections = torch.cat([bboxes, scores, clses], dim=2)

    return detections, keep
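
A shape sketch for bbox_decode on dummy tensors (sizes assumed): each decoded row is [x1, y1, ..., x4, y4, score, class] on the output-stride grid.

import torch

heat = torch.rand(1, 1, 64, 64)   # cell-center heatmap (channel 0 of 'hm')
wh = torch.rand(1, 8, 64, 64)     # 'c2v' center-to-vertex offsets
reg = torch.rand(1, 2, 64, 64)    # sub-pixel center refinement
dets, keep = bbox_decode(heat, wh, reg=reg, K=100)
print(dets.shape)                 # torch.Size([1, 100, 10])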


def gbox_decode(mk, st_reg, reg=None, K=400):
    # Decode the top-K vertex groups: each row is the vertex (xs, ys), the
    # four cell centers it links to via the 'v2c' offsets, then score, class.
    batch, cat, height, width = mk.size()
    mk, keep = _nms(mk)
    scores, inds, clses, ys, xs = _topk(mk, K=K)
    if reg is not None:
        reg = _tranpose_and_gather_feat(reg, inds)
        reg = reg.view(batch, K, 2)
        xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
        ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
    else:
        xs = xs.view(batch, K, 1) + 0.5
        ys = ys.view(batch, K, 1) + 0.5
    scores = scores.view(batch, K, 1)
    clses = clses.view(batch, K, 1).float()
    st_Reg = _tranpose_and_gather_feat(st_reg, inds)
    bboxes = torch.cat(
        [
            xs - st_Reg[..., 0:1],
            ys - st_Reg[..., 1:2],
            xs - st_Reg[..., 2:3],
            ys - st_Reg[..., 3:4],
            xs - st_Reg[..., 4:5],
            ys - st_Reg[..., 5:6],
            xs - st_Reg[..., 6:7],
            ys - st_Reg[..., 7:8],
        ],
        dim=2,
    )
    return torch.cat([xs, ys, bboxes, scores, clses], dim=2), keep


def bbox_post_process(bbox, c, s, h, w):
    # Map the four predicted corners back to original-image coordinates.
    for i in range(bbox.shape[0]):
        bbox[i, :, 0:2] = transform_preds(bbox[i, :, 0:2], c[i], s[i], (w, h))
        bbox[i, :, 2:4] = transform_preds(bbox[i, :, 2:4], c[i], s[i], (w, h))
        bbox[i, :, 4:6] = transform_preds(bbox[i, :, 4:6], c[i], s[i], (w, h))
        bbox[i, :, 6:8] = transform_preds(bbox[i, :, 6:8], c[i], s[i], (w, h))
    return bbox


def gbox_post_process(gbox, c, s, h, w):
    for i in range(gbox.shape[0]):
        gbox[i, :, 0:2] = transform_preds(gbox[i, :, 0:2], c[i], s[i], (w, h))
        gbox[i, :, 2:4] = transform_preds(gbox[i, :, 2:4], c[i], s[i], (w, h))
        gbox[i, :, 4:6] = transform_preds(gbox[i, :, 4:6], c[i], s[i], (w, h))
        gbox[i, :, 6:8] = transform_preds(gbox[i, :, 6:8], c[i], s[i], (w, h))
        gbox[i, :, 8:10] = transform_preds(gbox[i, :, 8:10], c[i], s[i],
                                           (w, h))
    return gbox


def nms(dets, thresh):
    # Suppress a quadrilateral whose center falls inside a higher-scoring
    # quadrilateral. Note: called from the pipeline with a batched
    # (1, K, 10) array, len(dets) is 1 and the input is returned unchanged.
    if len(dets) < 2:
        return dets
    index_keep = []
    keep = []
    for i in range(len(dets)):
        box = dets[i]
        if box[-1] < thresh:
            break
        max_score_index = -1
        ctx = (dets[i][0] + dets[i][2] + dets[i][4] + dets[i][6]) / 4
        cty = (dets[i][1] + dets[i][3] + dets[i][5] + dets[i][7]) / 4
        for j in range(len(dets)):
            if i == j or dets[j][-1] < thresh:
                break
            x1, y1 = dets[j][0], dets[j][1]
            x2, y2 = dets[j][2], dets[j][3]
            x3, y3 = dets[j][4], dets[j][5]
            x4, y4 = dets[j][6], dets[j][7]
            # Same-side cross-product test: is (ctx, cty) inside box j?
            a = (x2 - x1) * (cty - y1) - (y2 - y1) * (ctx - x1)
            b = (x3 - x2) * (cty - y2) - (y3 - y2) * (ctx - x2)
            c = (x4 - x3) * (cty - y3) - (y4 - y3) * (ctx - x3)
            d = (x1 - x4) * (cty - y4) - (y1 - y4) * (ctx - x4)
            if (a > 0 and b > 0 and c > 0 and d > 0) or (a < 0 and b < 0
                                                         and c < 0 and d < 0):
                if dets[i][8] > dets[j][8] and max_score_index < 0:
                    max_score_index = i
                elif dets[i][8] < dets[j][8]:
                    max_score_index = -2
                    break
        if max_score_index > -1:
            index_keep.append(max_score_index)
        elif max_score_index == -1:
            index_keep.append(i)
    for i in range(0, len(index_keep)):
        keep.append(dets[index_keep[i]])
    return np.array(keep)


def group_bbox_by_gbox(bboxes,
                       gboxes,
                       score_thred=0.3,
                       v2c_dist_thred=2,
                       c2v_dist_thred=0.5):
    # Snap cell-box corners to detected vertices: when a vertex group links
    # to a center lying inside a cell box, move that box's nearest corner
    # onto the vertex.

    def point_in_box(box, point):
        x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
        x3, y3, x4, y4 = box[4], box[5], box[6], box[7]
        ctx, cty = point[0], point[1]
        a = (x2 - x1) * (cty - y1) - (y2 - y1) * (ctx - x1)
        b = (x3 - x2) * (cty - y2) - (y3 - y2) * (ctx - x2)
        c = (x4 - x3) * (cty - y3) - (y4 - y3) * (ctx - x3)
        d = (x1 - x4) * (cty - y4) - (y1 - y4) * (ctx - x4)
        if (a > 0 and b > 0 and c > 0 and d > 0) or (a < 0 and b < 0 and c < 0
                                                     and d < 0):
            return True
        else:
            return False

    def get_distance(pt1, pt2):
        return math.sqrt((pt1[0] - pt2[0]) * (pt1[0] - pt2[0])
                         + (pt1[1] - pt2[1]) * (pt1[1] - pt2[1]))

    dets = copy.deepcopy(bboxes)
    sign = np.zeros((len(dets), 4))

    for idx, gbox in enumerate(gboxes):  # vertex x, y, linked centers, score
        if gbox[10] < score_thred:
            break
        vertex = [gbox[0], gbox[1]]
        for i in range(0, 4):
            center = [gbox[2 * i + 2], gbox[2 * i + 3]]
            if get_distance(vertex, center) < v2c_dist_thred:
                continue
            for k, bbox in enumerate(dets):
                if bbox[8] < score_thred:
                    break
                if sum(sign[k]) == 4:
                    continue
                w = (abs(bbox[6] - bbox[0]) + abs(bbox[4] - bbox[2])) / 2
                h = (abs(bbox[3] - bbox[1]) + abs(bbox[5] - bbox[7])) / 2
                m = max(w, h)
                if point_in_box(bbox, center):
                    min_dist, min_id = 1e4, -1
                    for j in range(0, 4):
                        dist = get_distance(vertex,
                                            [bbox[2 * j], bbox[2 * j + 1]])
                        if dist < min_dist:
                            min_dist = dist
                            min_id = j
                    if (min_id > -1 and min_dist < c2v_dist_thred * m
                            and sign[k][min_id] == 0):
                        bboxes[k][2 * min_id] = vertex[0]
                        bboxes[k][2 * min_id + 1] = vertex[1]
                        sign[k][min_id] = 1
    return bboxes
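
The nested point_in_box above is a same-side cross-product test; restated standalone for illustration (point_in_quad is a hypothetical helper, not exported by this module):

def point_in_quad(box, point):
    # box: corners [x1, y1, ..., x4, y4] in order; point: [x, y]
    cross = []
    for i in range(4):
        x1, y1 = box[2 * i], box[2 * i + 1]
        x2, y2 = box[(2 * i + 2) % 8], box[(2 * i + 3) % 8]
        cross.append((x2 - x1) * (point[1] - y1) - (y2 - y1) * (point[0] - x1))
    # Inside iff the point is on the same side of all four edges.
    return all(v > 0 for v in cross) or all(v < 0 for v in cross)

print(point_in_quad([0, 0, 4, 0, 4, 2, 0, 2], [1, 1]))  # True
print(point_in_quad([0, 0, 4, 0, 4, 2, 0, 2], [5, 1]))  # False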

modelscope/pipelines/cv/table_recognition_pipeline.py  (+119 −0)

@@ -0,0 +1,119 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.cv.ocr_utils.model_dla34 import TableRecModel
from modelscope.pipelines.cv.ocr_utils.table_process import (
    bbox_decode, bbox_post_process, gbox_decode, gbox_post_process,
    get_affine_transform, group_bbox_by_gbox, nms)
from modelscope.preprocessors import load_image
from modelscope.preprocessors.image import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.table_recognition, module_name=Pipelines.table_recognition)
class TableRecognitionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {model_path}')

        self.K = 1000
        self.MK = 4000
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.infer_model = TableRecModel().to(self.device)
        self.infer_model.eval()
        checkpoint = torch.load(model_path, map_location=self.device)
        if 'state_dict' in checkpoint:
            self.infer_model.load_state_dict(checkpoint['state_dict'])
        else:
            self.infer_model.load_state_dict(checkpoint)

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)

        mean = np.array([0.408, 0.447, 0.470],
                        dtype=np.float32).reshape(1, 1, 3)
        std = np.array([0.289, 0.274, 0.278],
                       dtype=np.float32).reshape(1, 1, 3)
        height, width = img.shape[0:2]
        inp_height, inp_width = 1024, 1024
        c = np.array([width / 2., height / 2.], dtype=np.float32)
        s = max(height, width) * 1.0

        trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height])
        resized_image = cv2.resize(img, (width, height))
        inp_image = cv2.warpAffine(
            resized_image,
            trans_input, (inp_width, inp_height),
            flags=cv2.INTER_LINEAR)
        inp_image = ((inp_image / 255. - mean) / std).astype(np.float32)

        images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height,
                                                      inp_width)
        images = torch.from_numpy(images).to(self.device)
        meta = {
            'c': c,
            's': s,
            'input_height': inp_height,
            'input_width': inp_width,
            'out_height': inp_height // 4,
            'out_width': inp_width // 4
        }

        result = {'img': images, 'meta': meta}

        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        pred = self.infer_model(input['img'])
        return {'results': pred, 'meta': input['meta']}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        output = inputs['results'][0]
        meta = inputs['meta']
        # 'hm' channel 0 holds cell-center peaks, channel 1 vertex peaks.
        hm = output['hm'].sigmoid_()
        v2c = output['v2c']
        c2v = output['c2v']
        reg = output['reg']
        bbox, _ = bbox_decode(hm[:, 0:1, :, :], c2v, reg=reg, K=self.K)
        gbox, _ = gbox_decode(hm[:, 1:2, :, :], v2c, reg=reg, K=self.MK)

        bbox = bbox.detach().cpu().numpy()
        gbox = gbox.detach().cpu().numpy()
        bbox = nms(bbox, 0.3)
        bbox = bbox_post_process(bbox.copy(), [meta['c'].cpu().numpy()],
                                 [meta['s']], meta['out_height'],
                                 meta['out_width'])
        gbox = gbox_post_process(gbox.copy(), [meta['c'].cpu().numpy()],
                                 [meta['s']], meta['out_height'],
                                 meta['out_width'])
        bbox = group_bbox_by_gbox(bbox[0], gbox[0])

        res = []
        for box in bbox:
            if box[8] > 0.3:
                res.append(box[0:8])

        result = {OutputKeys.POLYGONS: np.array(res)}
        return result
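
To eyeball the output, a short visualization sketch (assumes result comes from a preceding pipeline call and that OpenCV is available):

import cv2
import numpy as np
from modelscope.outputs import OutputKeys

img = cv2.imread('data/test/images/table_recognition.jpg')
for poly in result[OutputKeys.POLYGONS]:
    pts = np.round(np.array(poly)).astype(np.int32).reshape(-1, 2)
    cv2.polylines(img, [pts], isClosed=True, color=(0, 255, 0), thickness=2)
cv2.imwrite('table_cells_vis.jpg', img)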

modelscope/utils/constant.py  (+1 −0)

@@ -16,6 +16,7 @@ class CVTasks(object):
    # ocr
    ocr_detection = 'ocr-detection'
    ocr_recognition = 'ocr-recognition'
    table_recognition = 'table-recognition'

    # human face body related
    animal_recognition = 'animal-recognition'


tests/pipelines/test_table_recognition.py  (+41 −0)

@@ -0,0 +1,41 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class TableRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_dla34_table-structure-recognition_cycle-centernet'
        self.test_image = 'data/test/images/table_recognition.jpg'
        self.task = Tasks.table_recognition

    def pipeline_inference(self, pipe: Pipeline, input_location: str):
        result = pipe(input_location)
        print('table recognition results: ')
        print(result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        table_recognition = pipeline(
            Tasks.table_recognition, model=self.model_id)
        self.pipeline_inference(table_recognition, self.test_image)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        table_recognition = pipeline(Tasks.table_recognition)
        self.pipeline_inference(table_recognition, self.test_image)

    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()

tests/run_config.yaml  (+1 −0)

@@ -39,6 +39,7 @@ isolated: # test cases that may require excessive amount of GPU memory or run
  - test_automatic_speech_recognition.py
  - test_image_matting.py
  - test_skin_retouching.py
  - test_table_recognition.py

envs:
  default: # default env; cases not listed under another env run here, pytorch.

