[to #42322933] Merge request from 仲理:feat/product_feature

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9515599
Branch: master · dangwei.ldw, yingda.chen · 3 years ago · commit e93339ea87
15 changed files with 922 additions and 4 deletions

1. data/test/images/product_embed_bag.jpg (+3, -0)
2. modelscope/metainfo.py (+2, -0)
3. modelscope/models/cv/__init__.py (+2, -2)
4. modelscope/models/cv/product_retrieval_embedding/__init__.py (+23, -0)
5. modelscope/models/cv/product_retrieval_embedding/item_detection.py (+517, -0)
6. modelscope/models/cv/product_retrieval_embedding/item_embedding.py (+157, -0)
7. modelscope/models/cv/product_retrieval_embedding/item_model.py (+115, -0)
8. modelscope/outputs.py (+6, -0)
9. modelscope/pipelines/builder.py (+3, -0)
10. modelscope/pipelines/cv/__init__.py (+4, -1)
11. modelscope/pipelines/cv/product_retrieval_embedding_pipeline.py (+45, -0)
12. modelscope/utils/constant.py (+1, -0)
13. requirements/cv.txt (+3, -0)
14. requirements/runtime.txt (+2, -1)
15. tests/pipelines/test_product_retrieval_embedding.py (+39, -0)

data/test/images/product_embed_bag.jpg (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:08691a9373aa6d05b236a4ba788f3eccdea4c37aa77b30fc94b02ec3e1f18210
size 367017

modelscope/metainfo.py (+2, -0)

@@ -16,6 +16,7 @@ class Models(object):
     nafnet = 'nafnet'
     csrnet = 'csrnet'
     cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin'
+    product_retrieval_embedding = 'product-retrieval-embedding'
 
     # nlp models
     bert = 'bert'
@@ -84,6 +85,7 @@ class Pipelines(object):
     image_super_resolution = 'rrdb-image-super-resolution'
     face_image_generation = 'gan-face-image-generation'
     style_transfer = 'AAMS-style-transfer'
+    product_retrieval_embedding = 'resnet50-product-retrieval-embedding'
     face_recognition = 'ir101-face-recognition-cfglint'
     image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
     image2image_translation = 'image-to-image-translation'


modelscope/models/cv/__init__.py (+2, -2)

@@ -3,5 +3,5 @@ from . import (action_recognition, animal_recognition, cartoon,
                cmdssl_video_embedding, face_detection, face_generation,
                image_classification, image_color_enhance, image_colorization,
                image_denoise, image_instance_segmentation,
-               image_to_image_translation, object_detection, super_resolution,
-               virual_tryon)
+               image_to_image_translation, object_detection,
+               product_retrieval_embedding, super_resolution, virual_tryon)

modelscope/models/cv/product_retrieval_embedding/__init__.py (+23, -0)

@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:

from .item_model import ProductRetrievalEmbedding

else:
_import_structure = {
'item_model': ['ProductRetrievalEmbedding'],
}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
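This is the lazy-import pattern used across modelscope: importing the package only registers a LazyImportModule, and the heavyweight submodule (which pulls in torch and onnxruntime) is imported the first time ProductRetrievalEmbedding is accessed. A hedged sketch of the effect, assuming nothing else has imported the submodule yet:

import sys

import modelscope.models.cv.product_retrieval_embedding as pre

mod = 'modelscope.models.cv.product_retrieval_embedding.item_model'
print(mod in sys.modules)  # False: the package import was cheap

cls = pre.ProductRetrievalEmbedding  # attribute access triggers the real import
print(mod in sys.modules)  # True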

modelscope/models/cv/product_retrieval_embedding/item_detection.py (+517, -0)

@@ -0,0 +1,517 @@
import cv2
import numpy as np


class YOLOXONNX(object):
    """
    Product detection model with onnx inference
    """

    def __init__(self, onnx_path, multi_detect=False):
        """Create product detection model
        Args:
            onnx_path: onnx model path for product detection
            multi_detect: detection parameter, should be set as False
        """
        self.input_reso = 416
        self.iou_thr = 0.45
        self.score_thr = 0.3
        self.img_shape = (self.input_reso, self.input_reso, 3)
        self.num_classes = 13
        self.onnx_path = onnx_path
        import onnxruntime as ort
        self.ort_session = ort.InferenceSession(self.onnx_path)
        self.with_p6 = False
        self.multi_detect = multi_detect

    def format_judge(self, img):
        """Shrink images larger than 1024x1024 so the long side is capped
        at 1024 while the short side stays at or above 100 pixels."""
        m_min_width = 100
        m_min_height = 100

        height, width, c = img.shape

        if width * height > 1024 * 1024:
            if height > width:
                long_side = height
                short_side = width
                long_ratio = float(long_side) / 1024.0
                short_ratio = float(short_side) / float(m_min_width)
            else:
                long_side = width
                short_side = height
                long_ratio = float(long_side) / 1024.0
                short_ratio = float(short_side) / float(m_min_height)

            if long_side == height:
                if long_ratio < short_ratio:
                    height_new = 1024
                    width_new = int((1024 * width) / height)
                else:
                    height_new = int((m_min_width * height) / width)
                    width_new = m_min_width
                img_res = cv2.resize(img, (width_new, height_new),
                                     interpolation=cv2.INTER_LINEAR)
            else:  # long_side == width
                if long_ratio < short_ratio:
                    height_new = int((1024 * height) / width)
                    width_new = 1024
                else:
                    width_new = int((m_min_height * width) / height)
                    height_new = m_min_height
                img_res = cv2.resize(img, (width_new, height_new),
                                     interpolation=cv2.INTER_LINEAR)
        else:
            img_res = img

        return img_res

    def preprocess(self, image, input_size, swap=(2, 0, 1)):
        """Letterbox the image into the detector input size
        Args:
            image: cv2 image with BGR format
            input_size: model input size (h, w)
        """
        if len(image.shape) == 3:
            padded_img = np.ones((input_size[0], input_size[1], 3)) * 114.0
        else:
            padded_img = np.ones(input_size) * 114.0
        img = np.array(image)
        r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
        resized_img = cv2.resize(
            img,
            (int(img.shape[1] * r), int(img.shape[0] * r)),
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.float32)
        padded_img[:int(img.shape[0] * r), :int(img.shape[1] * r)] = resized_img

        padded_img = padded_img.transpose(swap)
        padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
        return padded_img, r
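Worked example of the letterbox above: for a 1000×500 (H×W) image and a 416×416 input, r = min(416/1000, 416/500) = 0.416, so the image is resized to 416×208 and pasted into the top-left corner of a 416×416 canvas pre-filled with the value 114. Detections are later mapped back to the original scale by dividing the box coordinates by r (see forward below).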

    def cal_iou(self, val1, val2):
        """IoU of two boxes in (x1, y1, x2, y2) format."""
        x11, y11, x12, y12 = val1
        x21, y21, x22, y22 = val2

        leftX = max(x11, x21)
        topY = max(y11, y21)
        rightX = min(x12, x22)
        bottomY = min(y12, y22)
        if rightX < leftX or bottomY < topY:
            return 0
        area = float((rightX - leftX) * (bottomY - topY))
        barea = (x12 - x11) * (y12 - y11) + (x22 - x21) * (y22 - y21) - area
        if barea <= 0:
            return 0
        return area / barea

    def nms(self, boxes, scores, nms_thr):
        """
        Single class NMS implemented in Numpy.
        """
        x1 = boxes[:, 0]
        y1 = boxes[:, 1]
        x2 = boxes[:, 2]
        y2 = boxes[:, 3]

        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]

        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])

            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[order[1:]] - inter)

            inds = np.where(ovr <= nms_thr)[0]
            order = order[inds + 1]

        return keep
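A quick numeric check of the greedy suppression above; since nms never touches self, it can be exercised without building an ONNX session (a sketch, using the same +1 pixel area convention as the code):

import numpy as np

from modelscope.models.cv.product_retrieval_embedding.item_detection import \
    YOLOXONNX

boxes = np.array([[0., 0., 10., 10.],
                  [1., 1., 11., 11.],
                  [50., 50., 60., 60.]])
scores = np.array([0.9, 0.8, 0.7])
# IoU(box0, box1) = 100 / (121 + 121 - 100) ≈ 0.70 > 0.45, so box1 is
# suppressed; box2 is disjoint from box0 and survives.
keep = YOLOXONNX.nms(None, boxes, scores, nms_thr=0.45)
print(keep)  # indices [0, 2]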

    def multiclass_nms(self, boxes, scores, nms_thr, score_thr):
        """
        Multiclass NMS implemented in Numpy
        """
        final_dets = []
        num_classes = scores.shape[1]
        for cls_ind in range(num_classes):
            cls_scores = scores[:, cls_ind]
            valid_score_mask = cls_scores > score_thr
            if valid_score_mask.sum() == 0:
                continue
            else:
                valid_scores = cls_scores[valid_score_mask]
                valid_boxes = boxes[valid_score_mask]
                keep = self.nms(valid_boxes, valid_scores, nms_thr)
                if len(keep) > 0:
                    cls_inds = np.ones((len(keep), 1)) * cls_ind
                    dets = np.concatenate([
                        valid_boxes[keep], valid_scores[keep, None], cls_inds
                    ], 1)
                    final_dets.append(dets)
        if len(final_dets) == 0:
            return None
        return np.concatenate(final_dets, 0)

    def postprocess(self, outputs, img_size, p6=False):
        """Decode grid-relative YOLOX outputs into center-size boxes
        in input-image pixels."""
        grids = []
        expanded_strides = []

        if not p6:
            strides = [8, 16, 32]
        else:
            strides = [8, 16, 32, 64]

        hsizes = [img_size[0] // stride for stride in strides]
        wsizes = [img_size[1] // stride for stride in strides]

        for hsize, wsize, stride in zip(hsizes, wsizes, strides):
            xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
            grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
            grids.append(grid)
            shape = grid.shape[:2]
            expanded_strides.append(np.full((*shape, 1), stride))

        grids = np.concatenate(grids, 1)
        expanded_strides = np.concatenate(expanded_strides, 1)
        outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
        outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides

        return outputs
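The decoding follows the standard anchor-free YOLOX convention: for a prediction (dx, dy, dw, dh) in grid cell (gx, gy) at stride s, the box center is ((dx + gx) · s, (dy + gy) · s) and the size is (exp(dw) · s, exp(dh) · s). For example, a raw prediction (0.5, 0.5, 0.0, 0.0) in cell (10, 7) of the stride-16 map decodes to a 16×16 box centered at (168, 120), since exp(0) = 1.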

    def get_new_box_order(self, bboxes, labels, img_h, img_w):
        """
        refine bbox order: sort boxes by score, greedily keep boxes whose
        IoU with every already-kept box is below 0.45, then truncate to the
        number of boxes scoring above 0.3 (at least one)
        """
        bboxes = np.hstack((bboxes, np.zeros((bboxes.shape[0], 1))))
        scores = bboxes[:, 4]
        order = scores.argsort()[::-1]
        bboxes_temp = bboxes[order]
        labels_temp = labels[order]
        bboxes = np.empty((0, 6))
        bboxes = np.vstack((bboxes, bboxes_temp[0].tolist()))
        labels = np.empty((0, ))

        labels = np.hstack((labels, [labels_temp[0]]))
        for i in range(1, bboxes_temp.shape[0]):
            iou_max = 0
            for j in range(bboxes.shape[0]):
                iou_temp = self.cal_iou(bboxes_temp[i][:4], bboxes[j][:4])
                if iou_temp > iou_max:
                    iou_max = iou_temp
            if iou_max < 0.45:
                bboxes = np.vstack((bboxes, bboxes_temp[i].tolist()))
                labels = np.hstack((labels, [labels_temp[i]]))

        num_03 = scores > 0.3
        num_03 = num_03.sum()
        num_out = max(num_03, 1)
        bboxes = bboxes[:num_out, :]
        labels = labels[:num_out]

        return bboxes, labels

    def forward(self, img_input, cid='0', sub_class=False):
        """
        forward for product detection
        Args:
            img_input: cv2 image with BGR format
            cid: category id used to keep detections of one product type
            sub_class: whether cid indexes the fine-grained category space
        """
        input_shape = self.img_shape

        img, ratio = self.preprocess(img_input, input_shape)
        img_h, img_w = img_input.shape[:2]

        ort_inputs = {
            self.ort_session.get_inputs()[0].name: img[None, :, :, :]
        }

        output = self.ort_session.run(None, ort_inputs)

        predictions = self.postprocess(output[0], input_shape, self.with_p6)[0]

        boxes = predictions[:, :4]
        scores = predictions[:, 4:5] * predictions[:, 5:]

        # (cx, cy, w, h) -> (x1, y1, x2, y2), mapped back to the original scale
        boxes_xyxy = np.ones_like(boxes)
        boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
        boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
        boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
        boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
        boxes_xyxy /= ratio
        dets = self.multiclass_nms(
            boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)

        if dets is None:
            # no detection: fall back to the whole image
            coord = str(0) + ',' + str(img_w) + ',' + str(0) + ',' + str(img_h)
            crop_img = img_input.copy()
        else:
            bboxes = dets[:, :5]
            labels = dets[:, 5]

            if not self.multi_detect:
                cid = int(cid)
                if not sub_class:
                    # coarse categories over the 13 detector labels
                    if cid == 0:  # cloth
                        wanted = (0, 1, 2, 9)
                    elif cid == 3:  # bag
                        wanted = (3, )
                    elif cid == 4:  # shoe
                        wanted = (4, )
                    else:  # other
                        wanted = (5, 6, 7, 8, 10, 11, 12)
                else:
                    # fine-grained categories
                    fine_sets = {
                        0: (0, ),  # upper
                        1: (1, ),  # skirt
                        2: (2, ),  # lower
                        3: (3, ),  # bag
                        4: (4, ),  # shoe
                        5: (5, ),  # access
                        7: (6, ),  # beauty
                        9: (8, ),  # furniture
                        21: (9, ),  # underwear
                        22: (11, ),  # digital
                    }
                    # default ("other"): 7 (bottle), 10 and 12 (toy)
                    wanted = fine_sets.get(cid, (7, 10, 12))
                if cid > -1:
                    cid_ind = np.where(np.isin(labels, wanted))[0]
                    if cid_ind.size > 0:
                        bboxes = bboxes[cid_ind]
                        labels = labels[cid_ind]
                        bboxes, labels = self.get_new_box_order(
                            bboxes, labels, img_h, img_w)
                else:
                    bboxes, labels = self.get_new_box_order(
                        bboxes, labels, img_h, img_w)
            else:
                bboxes, labels = self.get_new_box_order(
                    bboxes, labels, img_h, img_w)

            top1_bbox = bboxes[0].astype(np.int32)
            # clamp the top-1 box to the image bounds
            top1_bbox[0] = min(max(0, top1_bbox[0]), img_input.shape[1] - 1)
            top1_bbox[1] = min(max(0, top1_bbox[1]), img_input.shape[0] - 1)
            top1_bbox[2] = max(min(img_input.shape[1] - 1, top1_bbox[2]), 0)
            top1_bbox[3] = max(min(img_input.shape[0] - 1, top1_bbox[3]), 0)
            if not self.multi_detect:
                coord = str(top1_bbox[0]) + ',' + str(
                    top1_bbox[2]) + ',' + str(top1_bbox[1]) + ',' + str(
                        top1_bbox[3])  # x1, x2, y1, y2
            else:
                coord = ''
                for i in range(0, len(bboxes)):
                    top_bbox = bboxes[i].astype(np.int32)
                    top_bbox[0] = min(
                        max(0, top_bbox[0]), img_input.shape[1] - 1)
                    top_bbox[1] = min(
                        max(0, top_bbox[1]), img_input.shape[0] - 1)
                    top_bbox[2] = max(
                        min(img_input.shape[1] - 1, top_bbox[2]), 0)
                    top_bbox[3] = max(
                        min(img_input.shape[0] - 1, top_bbox[3]), 0)
                    # x1, x2, y1, y2, conf, label
                    coord = coord + str(top_bbox[0]) + ',' + str(
                        top_bbox[2]) + ',' + str(top_bbox[1]) + ',' + str(
                            top_bbox[3]) + ',' + str(bboxes[i][4]) + ',' + str(
                                bboxes[i][5]) + ';'
            crop_img = img_input[top1_bbox[1]:top1_bbox[3],
                                 top1_bbox[0]:top1_bbox[2], :]

        crop_img = cv2.resize(crop_img, (224, 224))

        return coord, crop_img  # coord string and 224x224 top-1 crop
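A minimal sketch of running the detector standalone (the image path comes from this change's test data; 'onnx_detection.onnx' matches the file name item_model.py loads from the model directory, so a local copy of that file is assumed):

import cv2

from modelscope.models.cv.product_retrieval_embedding.item_detection import \
    YOLOXONNX

detector = YOLOXONNX(onnx_path='onnx_detection.onnx', multi_detect=False)
img = cv2.imread('data/test/images/product_embed_bag.jpg')  # BGR, as expected
coord, crop = detector.forward(img, cid='3')  # '3' keeps the bag category
print(coord)       # 'x1,x2,y1,y2' of the top-1 box
print(crop.shape)  # (224, 224, 3): the crop fed to the embedding network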

modelscope/models/cv/product_retrieval_embedding/item_embedding.py (+157, -0)

@@ -0,0 +1,157 @@
import os
import time

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def gn_init(m, zero_init=False):
    assert isinstance(m, nn.GroupNorm)
    m.weight.data.fill_(0. if zero_init else 1.)
    m.bias.data.zero_()


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """Bottleneck for resnet-style networks
        Args:
            inplanes: input channel number
            planes: output channel number
        """
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.GroupNorm(32, planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False)
        self.bn2 = nn.GroupNorm(32, planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.GroupNorm(32, planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

        gn_init(self.bn1)
        gn_init(self.bn2)
        gn_init(self.bn3, zero_init=True)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """
    resnet-style network with group normalization
    """

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.GroupNorm(32, 64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1)

        self.gap = nn.AvgPool2d((14, 14))
        self.reduce_conv = nn.Conv2d(2048, 512, kernel_size=1)

        gn_init(self.bn1)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.AvgPool2d(stride, stride),
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=1,
                    bias=False),
                nn.GroupNorm(32, planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.gap(x)
        x = self.reduce_conv(x)  # 512 channels

        x = x.view(x.size(0), -1)  # (N, 512)
        return F.normalize(x, p=2, dim=1)


def preprocess(img):
    """
    preprocess the image with cv2-bgr style to tensor
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    img_size = 224
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_new = cv2.resize(
        img, (img_size, img_size), interpolation=cv2.INTER_LINEAR)
    content = np.array(img_new).astype(np.float32)
    content = (content / 255.0 - mean) / std
    # transpose HWC -> NCHW with a leading batch dimension
    img_new = content.transpose(2, 0, 1)
    img_new = img_new[np.newaxis, :, :, :]
    return img_new


def resnet50_embed():
    """
    create resnet50 network with group normalization
    """
    net = ResNet(Bottleneck, [3, 4, 6, 3])
    return net
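The embedding tower can be exercised on its own; a sketch with randomly initialized weights (item_model.py below shows how the released checkpoint is loaded on top of this):

import cv2
import numpy as np
import torch

from modelscope.models.cv.product_retrieval_embedding.item_embedding import (
    preprocess, resnet50_embed)

net = resnet50_embed().eval()
img = cv2.imread('data/test/images/product_embed_bag.jpg')  # BGR image
x = torch.from_numpy(preprocess(img).astype(np.float32))    # (1, 3, 224, 224)
with torch.no_grad():
    emb = net(x)
print(emb.shape)          # torch.Size([1, 512])
print(float(emb.norm()))  # ~1.0, since the output is L2-normalized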

modelscope/models/cv/product_retrieval_embedding/item_model.py (+115, -0)

@@ -0,0 +1,115 @@
import os.path as osp
from typing import Any, Dict

import numpy as np
import torch

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.product_retrieval_embedding.item_detection import \
    YOLOXONNX
from modelscope.models.cv.product_retrieval_embedding.item_embedding import (
    preprocess, resnet50_embed)
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import create_device

logger = get_logger()

__all__ = ['ProductRetrievalEmbedding']


@MODELS.register_module(
    Tasks.product_retrieval_embedding,
    module_name=Models.product_retrieval_embedding)
class ProductRetrievalEmbedding(TorchModel):

    def __init__(self, model_dir, device='cpu', **kwargs):
        super().__init__(model_dir=model_dir, device=device, **kwargs)

        def filter_param(src_params, own_state):
            # copy parameters whose normalized name and shape both match
            copied_keys = []
            for name, param in src_params.items():
                if name[0:7] == 'module.':
                    name = name[7:]
                if '.module.' not in list(own_state.keys())[0]:
                    name = name.replace('.module.', '.')
                if (name in own_state) and (own_state[name].shape
                                            == param.shape):
                    own_state[name].copy_(param)
                    copied_keys.append(name)

        def load_pretrained(model, src_params):
            if 'state_dict' in src_params:
                src_params = src_params['state_dict']
            own_state = model.state_dict()
            filter_param(src_params, own_state)
            model.load_state_dict(own_state)

        cpu_flag = device == 'cpu'
        self.device = create_device(
            cpu_flag)  # device.type == 'cpu' or device.type == 'cuda'
        self.use_gpu = self.device.type == 'cuda'

        # config the model path
        self.local_model_dir = model_dir

        # init feat model
        self.preprocess_for_embed = preprocess  # input is cv2 bgr format
        model_feat = resnet50_embed()
        src_params = torch.load(
            osp.join(self.local_model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
            'cpu')
        load_pretrained(model_feat, src_params)
        if self.use_gpu:
            model_feat.to(self.device)
            logger.info('Use GPU: {}'.format(self.device))
        else:
            logger.info('Use CPU for inference')

        self.model_feat = model_feat

        # init det model
        self.model_det = YOLOXONNX(
            onnx_path=osp.join(self.local_model_dir, 'onnx_detection.onnx'),
            multi_detect=False)
        logger.info('load model done')

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """
        detection and feature extraction for the input product image
        """
        # input should be cv2 bgr format
        assert 'img' in input.keys()

        def set_phase(model, is_train):
            if is_train:
                model.train()
            else:
                model.eval()

        is_train = False
        set_phase(self.model_feat, is_train)
        img = input['img']  # for detection
        cid = '3'  # detection category: bag
        # transform img (tensor) to numpy array with bgr
        if isinstance(img, torch.Tensor):
            img = img.data.cpu().numpy()
        res, crop_img = self.model_det.forward(
            img, cid)  # detect with bag category
        crop_img = self.preprocess_for_embed(crop_img)  # feat preprocess
        input_tensor = torch.from_numpy(crop_img.astype(np.float32))
        device = next(self.model_feat.parameters()).device
        use_gpu = device.type == 'cuda'
        with torch.no_grad():
            if use_gpu:
                input_tensor = input_tensor.to(device)
            out_embedding = self.model_feat(input_tensor)
            out_embedding = out_embedding.cpu().numpy()[
                0, :]  # feature array with 512 elements

        output = {OutputKeys.IMG_EMBEDDING: out_embedding}
        return output
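For direct use without the pipeline wrapper, forward takes a dict with a BGR image under 'img'; a hedged sketch (Model.from_pretrained fetches the registered checkpoint, as the unit test below also does):

import cv2
from modelscope.models import Model
from modelscope.outputs import OutputKeys

model = Model.from_pretrained('damo/cv_resnet50_product-bag-embedding-models')
img = cv2.imread('data/test/images/product_embed_bag.jpg')  # BGR ndarray
emb = model.forward({'img': img})[OutputKeys.IMG_EMBEDDING]
print(emb.shape)  # (512,)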

modelscope/outputs.py (+6, -0)

@@ -140,6 +140,12 @@ TASK_OUTPUTS = {
     # }
     Tasks.ocr_detection: [OutputKeys.POLYGONS],
 
+    # image embedding result for a single image
+    # {
+    #   "img_embedding": np.array with shape [D]
+    # }
+    Tasks.product_retrieval_embedding: [OutputKeys.IMG_EMBEDDING],
+
     # video embedding result for single video
     # {
     #   "video_embedding": np.array with shape [D],


modelscope/pipelines/builder.py (+3, -0)

@@ -109,6 +109,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                                'damo/cv_gan_face-image-generation'),
     Tasks.image_super_resolution: (Pipelines.image_super_resolution,
                                    'damo/cv_rrdb_image-super-resolution'),
+    Tasks.product_retrieval_embedding:
+    (Pipelines.product_retrieval_embedding,
+     'damo/cv_resnet50_product-bag-embedding-models'),
     Tasks.image_classification_imagenet:
     (Pipelines.general_image_classification,
      'damo/cv_vit-base_image-classification_ImageNet-labels'),


modelscope/pipelines/cv/__init__.py (+4, -1)

@@ -11,6 +11,7 @@ if TYPE_CHECKING:
     from .face_detection_pipeline import FaceDetectionPipeline
     from .face_recognition_pipeline import FaceRecognitionPipeline
     from .face_image_generation_pipeline import FaceImageGenerationPipeline
+    from .image_classification_pipeline import ImageClassificationPipeline
     from .image_cartoon_pipeline import ImageCartoonPipeline
     from .image_classification_pipeline import GeneralImageClassificationPipeline
     from .image_denoise_pipeline import ImageDenoisePipeline
@@ -20,12 +21,12 @@ if TYPE_CHECKING:
     from .image_matting_pipeline import ImageMattingPipeline
     from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
     from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline
+    from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline
     from .style_transfer_pipeline import StyleTransferPipeline
     from .live_category_pipeline import LiveCategoryPipeline
     from .ocr_detection_pipeline import OCRDetectionPipeline
     from .video_category_pipeline import VideoCategoryPipeline
     from .virtual_tryon_pipeline import VirtualTryonPipeline
-    from .image_classification_pipeline import ImageClassificationPipeline
 else:
     _import_structure = {
         'action_recognition_pipeline': ['ActionRecognitionPipeline'],
@@ -47,6 +48,8 @@ else:
         'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'],
         'image_to_image_translation_pipeline':
         ['Image2ImageTranslationPipeline'],
+        'product_retrieval_embedding_pipeline':
+        ['ProductRetrievalEmbeddingPipeline'],
         'live_category_pipeline': ['LiveCategoryPipeline'],
         'ocr_detection_pipeline': ['OCRDetectionPipeline'],
         'style_transfer_pipeline': ['StyleTransferPipeline'],


modelscope/pipelines/cv/product_retrieval_embedding_pipeline.py (+45, -0)

@@ -0,0 +1,45 @@
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.product_retrieval_embedding,
    module_name=Pipelines.product_retrieval_embedding)
class ProductRetrievalEmbeddingPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """use `model` to create a pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """
        preprocess the input image to cv2-bgr style
        """
        img = LoadImage.convert_to_ndarray(input)  # array with rgb
        img = np.ascontiguousarray(img[:, :, ::-1])  # array with bgr
        result = {'img': img}  # only for detection
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        return self.model(input)

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
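Typical use goes through the pipeline API; the sketch below mirrors the unit test at the end of this change, with a local image path as input (LoadImage handles the conversion to a BGR array in preprocess above):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

product_embed = pipeline(
    Tasks.product_retrieval_embedding,
    model='damo/cv_resnet50_product-bag-embedding-models')
result = product_embed('data/test/images/product_embed_bag.jpg')
print(result[OutputKeys.IMG_EMBEDDING].shape)  # (512,), L2-normalized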

modelscope/utils/constant.py (+1, -0)

@@ -37,6 +37,7 @@ class CVTasks(object):
     face_image_generation = 'face-image-generation'
     image_super_resolution = 'image-super-resolution'
     style_transfer = 'style-transfer'
+    product_retrieval_embedding = 'product-retrieval-embedding'
     live_category = 'live-category'
     video_category = 'video-category'
     image_classification_imagenet = 'image-classification-imagenet'


requirements/cv.txt (+3, -0)

@@ -1,5 +1,8 @@
 decord>=0.6.0
 easydict
+# tensorflow 1.x compatibility requires the numpy version to be capped at 1.18
+numpy<=1.18
+onnxruntime>=1.10
 tf_slim
 timm
 torchvision

requirements/runtime.txt (+2, -1)

@@ -4,7 +4,8 @@ easydict
 einops
 filelock>=3.3.0
 gast>=0.2.2
-numpy
+# tensorflow 1.x compatibility requires the numpy version to be capped at 1.18
+numpy<=1.18
 opencv-python
 oss2
 Pillow>=6.2.0


tests/pipelines/test_product_retrieval_embedding.py (+39, -0)

@@ -0,0 +1,39 @@
import unittest

import numpy as np

from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class ProductRetrievalEmbeddingTest(unittest.TestCase):
    model_id = 'damo/cv_resnet50_product-bag-embedding-models'
    img_input = 'data/test/images/product_embed_bag.jpg'

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_model_name(self):
        product_embed = pipeline(Tasks.product_retrieval_embedding,
                                 self.model_id)
        result = product_embed(self.img_input)[OutputKeys.IMG_EMBEDDING]
        print('abs sum value is: {}'.format(np.sum(np.abs(result))))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        product_embed = pipeline(
            task=Tasks.product_retrieval_embedding, model=model)
        result = product_embed(self.img_input)[OutputKeys.IMG_EMBEDDING]
        print('abs sum value is: {}'.format(np.sum(np.abs(result))))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        product_embed = pipeline(task=Tasks.product_retrieval_embedding)
        result = product_embed(self.img_input)[OutputKeys.IMG_EMBEDDING]
        print('abs sum value is: {}'.format(np.sum(np.abs(result))))


if __name__ == '__main__':
    unittest.main()
