From e93339ea877b93fa0c1b9ebfeee8877f78facb0e Mon Sep 17 00:00:00 2001 From: "dangwei.ldw" Date: Mon, 1 Aug 2022 17:53:22 +0800 Subject: [PATCH] =?UTF-8?q?[to=20#42322933]Merge=20request=20from=20?= =?UTF-8?q?=E4=BB=B2=E7=90=86:feat/product=5Ffeature=20=20=20=20=20=20=20?= =?UTF-8?q?=20=20Link:=20https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/co?= =?UTF-8?q?dereview/9515599?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data/test/images/product_embed_bag.jpg | 3 + modelscope/metainfo.py | 2 + modelscope/models/cv/__init__.py | 4 +- .../product_retrieval_embedding/__init__.py | 23 + .../item_detection.py | 517 ++++++++++++++++++ .../item_embedding.py | 157 ++++++ .../product_retrieval_embedding/item_model.py | 115 ++++ modelscope/outputs.py | 6 + modelscope/pipelines/builder.py | 3 + modelscope/pipelines/cv/__init__.py | 5 +- .../product_retrieval_embedding_pipeline.py | 45 ++ modelscope/utils/constant.py | 1 + requirements/cv.txt | 3 + requirements/runtime.txt | 3 +- .../test_product_retrieval_embedding.py | 39 ++ 15 files changed, 922 insertions(+), 4 deletions(-) create mode 100644 data/test/images/product_embed_bag.jpg create mode 100644 modelscope/models/cv/product_retrieval_embedding/__init__.py create mode 100644 modelscope/models/cv/product_retrieval_embedding/item_detection.py create mode 100644 modelscope/models/cv/product_retrieval_embedding/item_embedding.py create mode 100644 modelscope/models/cv/product_retrieval_embedding/item_model.py create mode 100644 modelscope/pipelines/cv/product_retrieval_embedding_pipeline.py create mode 100644 tests/pipelines/test_product_retrieval_embedding.py diff --git a/data/test/images/product_embed_bag.jpg b/data/test/images/product_embed_bag.jpg new file mode 100644 index 00000000..8427c028 --- /dev/null +++ b/data/test/images/product_embed_bag.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:08691a9373aa6d05b236a4ba788f3eccdea4c37aa77b30fc94b02ec3e1f18210 +size 367017 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 75259f43..ec3ffc04 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -16,6 +16,7 @@ class Models(object): nafnet = 'nafnet' csrnet = 'csrnet' cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin' + product_retrieval_embedding = 'product-retrieval-embedding' # nlp models bert = 'bert' @@ -84,6 +85,7 @@ class Pipelines(object): image_super_resolution = 'rrdb-image-super-resolution' face_image_generation = 'gan-face-image-generation' style_transfer = 'AAMS-style-transfer' + product_retrieval_embedding = 'resnet50-product-retrieval-embedding' face_recognition = 'ir101-face-recognition-cfglint' image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation' image2image_translation = 'image-to-image-translation' diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index a96c6370..f5f12471 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -3,5 +3,5 @@ from . 
import (action_recognition, animal_recognition, cartoon, cmdssl_video_embedding, face_detection, face_generation, image_classification, image_color_enhance, image_colorization, image_denoise, image_instance_segmentation, - image_to_image_translation, object_detection, super_resolution, - virual_tryon) + image_to_image_translation, object_detection, + product_retrieval_embedding, super_resolution, virual_tryon) diff --git a/modelscope/models/cv/product_retrieval_embedding/__init__.py b/modelscope/models/cv/product_retrieval_embedding/__init__.py new file mode 100644 index 00000000..7a02a60f --- /dev/null +++ b/modelscope/models/cv/product_retrieval_embedding/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .item_model import ProductRetrievalEmbedding + +else: + _import_structure = { + 'item_model': ['ProductRetrievalEmbedding'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/product_retrieval_embedding/item_detection.py b/modelscope/models/cv/product_retrieval_embedding/item_detection.py new file mode 100644 index 00000000..4dd2914b --- /dev/null +++ b/modelscope/models/cv/product_retrieval_embedding/item_detection.py @@ -0,0 +1,517 @@ +import cv2 +import numpy as np + + +class YOLOXONNX(object): + """ + Product detection model with onnx inference + """ + + def __init__(self, onnx_path, multi_detect=False): + """Create product detection model + Args: + onnx_path: onnx model path for product detection + multi_detect: detection parameter, should be set as False + + """ + self.input_reso = 416 + self.iou_thr = 0.45 + self.score_thr = 0.3 + self.img_shape = tuple([self.input_reso, self.input_reso, 3]) + self.num_classes = 13 + self.onnx_path = onnx_path + import onnxruntime as ort + self.ort_session = ort.InferenceSession(self.onnx_path) + self.with_p6 = False + self.multi_detect = multi_detect + + def format_judge(self, img): + m_min_width = 100 + m_min_height = 100 + + height, width, c = img.shape + + if width * height > 1024 * 1024: + if height > width: + long_side = height + short_side = width + long_ratio = float(long_side) / 1024.0 + short_ratio = float(short_side) / float(m_min_width) + else: + long_side = width + short_side = height + long_ratio = float(long_side) / 1024.0 + short_ratio = float(short_side) / float(m_min_height) + + if long_side == height: + if long_ratio < short_ratio: + height_new = 1024 + width_new = (int)((1024 * width) / height) + + img_res = cv2.resize(img, (width_new, height_new), + cv2.INTER_LINEAR) + else: + height_new = (int)((m_min_width * height) / width) + width_new = m_min_width + + img_res = cv2.resize(img, (width_new, height_new), + cv2.INTER_LINEAR) + + elif long_side == width: + if long_ratio < short_ratio: + height_new = (int)((1024 * height) / width) + width_new = 1024 + + img_res = cv2.resize(img, (width_new, height_new), + cv2.INTER_LINEAR) + else: + width_new = (int)((m_min_height * width) / height) + height_new = m_min_height + + img_res = cv2.resize(img, (width_new, height_new), + cv2.INTER_LINEAR) + else: + img_res = img + + return img_res + + def preprocess(self, image, input_size, swap=(2, 0, 1)): + """ + Args: + image, cv2 image with BGR format + input_size, model input size + """ + if len(image.shape) == 3: + padded_img = 
np.ones((input_size[0], input_size[1], 3)) * 114.0 + else: + padded_img = np.ones(input_size) * 114.0 + img = np.array(image) + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, + ).astype(np.float32) + padded_img[:int(img.shape[0] * r), :int(img.shape[1] + * r)] = resized_img + + padded_img = padded_img.transpose(swap) + padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r + + def cal_iou(self, val1, val2): + x11, y11, x12, y12 = val1 + x21, y21, x22, y22 = val2 + + leftX = max(x11, x21) + topY = max(y11, y21) + rightX = min(x12, x22) + bottomY = min(y12, y22) + if rightX < leftX or bottomY < topY: + return 0 + area = float((rightX - leftX) * (bottomY - topY)) + barea = (x12 - x11) * (y12 - y11) + (x22 - x21) * (y22 - y21) - area + if barea <= 0: + return 0 + return area / barea + + def nms(self, boxes, scores, nms_thr): + """ + Single class NMS implemented in Numpy. + """ + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= nms_thr)[0] + order = order[inds + 1] + + return keep + + def multiclass_nms(self, boxes, scores, nms_thr, score_thr): + """ + Multiclass NMS implemented in Numpy + """ + final_dets = [] + num_classes = scores.shape[1] + for cls_ind in range(num_classes): + cls_scores = scores[:, cls_ind] + valid_score_mask = cls_scores > score_thr + if valid_score_mask.sum() == 0: + continue + else: + valid_scores = cls_scores[valid_score_mask] + valid_boxes = boxes[valid_score_mask] + keep = self.nms(valid_boxes, valid_scores, nms_thr) + if len(keep) > 0: + cls_inds = np.ones((len(keep), 1)) * cls_ind + dets = np.concatenate([ + valid_boxes[keep], valid_scores[keep, None], cls_inds + ], 1) + final_dets.append(dets) + if len(final_dets) == 0: + return None + return np.concatenate(final_dets, 0) + + def postprocess(self, outputs, img_size, p6=False): + grids = [] + expanded_strides = [] + + if not p6: + strides = [8, 16, 32] + else: + strides = [8, 16, 32, 64] + + hsizes = [img_size[0] // stride for stride in strides] + wsizes = [img_size[1] // stride for stride in strides] + + for hsize, wsize, stride in zip(hsizes, wsizes, strides): + xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) + grid = np.stack((xv, yv), 2).reshape(1, -1, 2) + grids.append(grid) + shape = grid.shape[:2] + expanded_strides.append(np.full((*shape, 1), stride)) + + grids = np.concatenate(grids, 1) + expanded_strides = np.concatenate(expanded_strides, 1) + outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides + outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides + + return outputs + + def get_new_box_order(self, bboxes, labels, img_h, img_w): + """ + refine bbox score + """ + bboxes = np.hstack((bboxes, np.zeros((bboxes.shape[0], 1)))) + scores = bboxes[:, 4] + order = scores.argsort()[::-1] + bboxes_temp = bboxes[order] + labels_temp = labels[order] + bboxes = np.empty((0, 6)) + # import pdb;pdb.set_trace() + bboxes = 
np.vstack((bboxes, bboxes_temp[0].tolist())) + labels = np.empty((0, )) + + labels = np.hstack((labels, [labels_temp[0]])) + for i in range(1, bboxes_temp.shape[0]): + iou_max = 0 + for j in range(bboxes.shape[0]): + iou_temp = self.cal_iou(bboxes_temp[i][:4], bboxes[j][:4]) + if (iou_temp > iou_max): + iou_max = iou_temp + if (iou_max < 0.45): + bboxes = np.vstack((bboxes, bboxes_temp[i].tolist())) + labels = np.hstack((labels, [labels_temp[i]])) + + num_03 = scores > 0.3 + num_03 = num_03.sum() + num_out = max(num_03, 1) + bboxes = bboxes[:num_out, :] + labels = labels[:num_out] + + return bboxes, labels + + def forward(self, img_input, cid='0', sub_class=False): + """ + forward for product detection + """ + input_shape = self.img_shape + + img, ratio = self.preprocess(img_input, input_shape) + img_h, img_w = img_input.shape[:2] + + ort_inputs = { + self.ort_session.get_inputs()[0].name: img[None, :, :, :] + } + + output = self.ort_session.run(None, ort_inputs) + + predictions = self.postprocess(output[0], input_shape, self.with_p6)[0] + + boxes = predictions[:, :4] + scores = predictions[:, 4:5] * predictions[:, 5:] + + boxes_xyxy = np.ones_like(boxes) + boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2. + boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2. + boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2. + boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2. + boxes_xyxy /= ratio + dets = self.multiclass_nms( + boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) + + if dets is None: + top1_bbox_str = str(0) + ',' + str(img_w) + ',' + str( + 0) + ',' + str(img_h) + crop_img = img_input.copy() + coord = top1_bbox_str + else: + bboxes = dets[:, :5] + labels = dets[:, 5] + + if not self.multi_detect: + cid = int(cid) + if (not sub_class): + if cid > -1: + if cid == 0: # cloth + cid_ind1 = np.where(labels < 3) + cid_ind2 = np.where(labels == 9) + cid_ind = np.hstack((cid_ind1[0], cid_ind2[0])) + scores = bboxes[cid_ind, -1] # 0, 1, 2, 9 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 3: # bag + cid_ind = np.where(labels == 3) + scores = bboxes[cid_ind, -1] # 3 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 4: # shoe + cid_ind = np.where(labels == 4) + scores = bboxes[cid_ind, -1] # 4 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + else: # other + cid_ind5 = np.where(labels == 5) + cid_ind6 = np.where(labels == 6) + cid_ind7 = np.where(labels == 7) + cid_ind8 = np.where(labels == 8) + cid_ind10 = np.where(labels == 10) + cid_ind11 = np.where(labels == 11) + cid_ind12 = np.where(labels == 12) + cid_ind = np.hstack( + (cid_ind5[0], cid_ind6[0], cid_ind7[0], + cid_ind8[0], cid_ind10[0], cid_ind11[0], + cid_ind12[0])) + scores = bboxes[cid_ind, -1] # 5,6,7,8,10,11,12 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + else: + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + else: + if cid > -1: + if cid == 0: # upper + cid_ind = np.where(labels == 0) + + scores = bboxes[cid_ind, -1] # 0, 1, 2, 9 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + 
bboxes, labels, img_h, img_w) + + elif cid == 1: # skirt + cid_ind = np.where(labels == 1) + scores = bboxes[cid_ind, -1] # 3 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 2: # lower + cid_ind = np.where(labels == 2) + scores = bboxes[cid_ind, -1] # 3 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 3: # bag + cid_ind = np.where(labels == 3) + scores = bboxes[cid_ind, -1] # 3 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 4: # shoe + cid_ind = np.where(labels == 4) + scores = bboxes[cid_ind, -1] # 4 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 5: # access + cid_ind = np.where(labels == 5) + scores = bboxes[cid_ind, -1] # 3 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 7: # beauty + cid_ind = np.where(labels == 6) + scores = bboxes[cid_ind, -1] # 3 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 9: # furniture + cid_ind = np.where(labels == 8) + scores = bboxes[cid_ind, -1] # 3 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 21: # underwear + cid_ind = np.where(labels == 9) + scores = bboxes[cid_ind, -1] # 3 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + elif cid == 22: # digital + cid_ind = np.where(labels == 11) + scores = bboxes[cid_ind, -1] # 3 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + else: # other + cid_ind5 = np.where(labels == 7) # bottle + cid_ind6 = np.where(labels == 10) # toy + cid_ind7 = np.where(labels == 12) # toy + cid_ind = np.hstack( + (cid_ind5[0], cid_ind6[0], cid_ind7[0])) + scores = bboxes[cid_ind, -1] # 5,6,7 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + else: + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + else: + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + top1_bbox = bboxes[0].astype(np.int32) + top1_bbox[0] = min(max(0, top1_bbox[0]), img_input.shape[1] - 1) + top1_bbox[1] = min(max(0, top1_bbox[1]), img_input.shape[0] - 1) + top1_bbox[2] = max(min(img_input.shape[1] - 1, top1_bbox[2]), 0) + top1_bbox[3] = max(min(img_input.shape[0] - 1, top1_bbox[3]), 0) + if not self.multi_detect: + + top1_bbox_str = str(top1_bbox[0]) + ',' + str( + top1_bbox[2]) + ',' + str(top1_bbox[1]) + ',' + str( + top1_bbox[3]) # x1, x2, y1, y2 + crop_img = img_input[top1_bbox[1]:top1_bbox[3], + top1_bbox[0]:top1_bbox[2], :] + coord = top1_bbox_str + coord = '' + for i in range(0, len(bboxes)): + top_bbox = bboxes[i].astype(np.int32) + 
top_bbox[0] = min( + max(0, top_bbox[0]), img_input.shape[1] - 1) + top_bbox[1] = min( + max(0, top_bbox[1]), img_input.shape[0] - 1) + top_bbox[2] = max( + min(img_input.shape[1] - 1, top_bbox[2]), 0) + top_bbox[3] = max( + min(img_input.shape[0] - 1, top_bbox[3]), 0) + coord = coord + str(top_bbox[0]) + ',' + str( + top_bbox[2]) + ',' + str(top_bbox[1]) + ',' + str( + top_bbox[3]) + ',' + str(bboxes[i][4]) + ',' + str( + bboxes[i][5]) + ';' + + else: + coord = '' + for i in range(0, len(bboxes)): + top_bbox = bboxes[i].astype(np.int32) + top_bbox[0] = min( + max(0, top_bbox[0]), img_input.shape[1] - 1) + top_bbox[1] = min( + max(0, top_bbox[1]), img_input.shape[0] - 1) + top_bbox[2] = max( + min(img_input.shape[1] - 1, top_bbox[2]), 0) + top_bbox[3] = max( + min(img_input.shape[0] - 1, top_bbox[3]), 0) + coord = coord + str(top_bbox[0]) + ',' + str( + top_bbox[2]) + ',' + str(top_bbox[1]) + ',' + str( + top_bbox[3]) + ',' + str(bboxes[i][4]) + ',' + str( + bboxes[i][5]) + ';' # x1, x2, y1, y2, conf + crop_img = img_input[top1_bbox[1]:top1_bbox[3], + top1_bbox[0]:top1_bbox[2], :] + + crop_img = cv2.resize(crop_img, (224, 224)) + + return coord, crop_img # return top1 image and coord diff --git a/modelscope/models/cv/product_retrieval_embedding/item_embedding.py b/modelscope/models/cv/product_retrieval_embedding/item_embedding.py new file mode 100644 index 00000000..b01031d5 --- /dev/null +++ b/modelscope/models/cv/product_retrieval_embedding/item_embedding.py @@ -0,0 +1,157 @@ +import os +import time + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def gn_init(m, zero_init=False): + assert isinstance(m, nn.GroupNorm) + m.weight.data.fill_(0. if zero_init else 1.) + m.bias.data.zero_() + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + """Bottleneck for resnet-style networks + Args: + inplanes: input channel number + planes: output channel number + """ + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.GroupNorm(32, planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = nn.GroupNorm(32, planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.GroupNorm(32, planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + gn_init(self.bn1) + gn_init(self.bn2) + gn_init(self.bn3, zero_init=True) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + """ + resnet-style network with group normalization + """ + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.GroupNorm(32, 64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + 
self.layer4 = self._make_layer(block, 512, layers[3], stride=1) + + self.gap = nn.AvgPool2d((14, 14)) + self.reduce_conv = nn.Conv2d(2048, 512, kernel_size=1) + + gn_init(self.bn1) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.AvgPool2d(stride, stride), + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=1, + bias=False), + nn.GroupNorm(32, planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.gap(x) + x = self.reduce_conv(x) # 512 + + x = x.view(x.size(0), -1) # 512 + return F.normalize(x, p=2, dim=1) + + +def preprocess(img): + """ + preprocess the image with cv2-bgr style to tensor + """ + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + + img_size = 224 + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img_new = cv2.resize( + img, (img_size, img_size), interpolation=cv2.INTER_LINEAR) + content = np.array(img_new).astype(np.float32) + content = (content / 255.0 - mean) / std + # transpose + img_new = content.transpose(2, 0, 1) + img_new = img_new[np.newaxis, :, :, :] + return img_new + + +def resnet50_embed(): + """ + create resnet50 network with group normalization + """ + net = ResNet(Bottleneck, [3, 4, 6, 3]) + return net diff --git a/modelscope/models/cv/product_retrieval_embedding/item_model.py b/modelscope/models/cv/product_retrieval_embedding/item_model.py new file mode 100644 index 00000000..2a893669 --- /dev/null +++ b/modelscope/models/cv/product_retrieval_embedding/item_model.py @@ -0,0 +1,115 @@ +import os.path as osp +from typing import Any, Dict + +import numpy as np +import torch + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.product_retrieval_embedding.item_detection import \ + YOLOXONNX +from modelscope.models.cv.product_retrieval_embedding.item_embedding import ( + preprocess, resnet50_embed) +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.torch_utils import create_device + +logger = get_logger() + +__all__ = ['ProductRetrievalEmbedding'] + + +@MODELS.register_module( + Tasks.product_retrieval_embedding, + module_name=Models.product_retrieval_embedding) +class ProductRetrievalEmbedding(TorchModel): + + def __init__(self, model_dir, device='cpu', **kwargs): + super().__init__(model_dir=model_dir, device=device, **kwargs) + + def filter_param(src_params, own_state): + copied_keys = [] + for name, param in src_params.items(): + if 'module.' == name[0:7]: + name = name[7:] + if '.module.' 
not in list(own_state.keys())[0]: + name = name.replace('.module.', '.') + if (name in own_state) and (own_state[name].shape + == param.shape): + own_state[name].copy_(param) + copied_keys.append(name) + + def load_pretrained(model, src_params): + if 'state_dict' in src_params: + src_params = src_params['state_dict'] + own_state = model.state_dict() + filter_param(src_params, own_state) + model.load_state_dict(own_state) + + cpu_flag = device == 'cpu' + self.device = create_device( + cpu_flag) # device.type == "cpu" or device.type == "cuda" + self.use_gpu = self.device.type == 'cuda' + + # config the model path + self.local_model_dir = model_dir + + # init feat model + self.preprocess_for_embed = preprocess # input is cv2 bgr format + model_feat = resnet50_embed() + src_params = torch.load( + osp.join(self.local_model_dir, ModelFile.TORCH_MODEL_BIN_FILE), + 'cpu') + load_pretrained(model_feat, src_params) + if self.use_gpu: + model_feat.to(self.device) + logger.info('Use GPU: {}'.format(self.device)) + else: + logger.info('Use CPU for inference') + + self.model_feat = model_feat + + # init det model + self.model_det = YOLOXONNX( + onnx_path=osp.join(self.local_model_dir, 'onnx_detection.onnx'), + multi_detect=False) + logger.info('load model done') + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + """ + detection and feature extraction for input product image + """ + # input should be cv2 bgr format + assert 'img' in input.keys() + + def set_phase(model, is_train): + if is_train: + model.train() + else: + model.eval() + + is_train = False + set_phase(self.model_feat, is_train) + img = input['img'] # for detection + cid = '3' # preprocess detection category bag + # transform img(tensor) to numpy array with bgr + if isinstance(img, torch.Tensor): + img = img.data.cpu().numpy() + res, crop_img = self.model_det.forward(img, + cid) # detect with bag category + crop_img = self.preprocess_for_embed(crop_img) # feat preprocess + input_tensor = torch.from_numpy(crop_img.astype(np.float32)) + device = next(self.model_feat.parameters()).device + use_gpu = device.type == 'cuda' + with torch.no_grad(): + if use_gpu: + input_tensor = input_tensor.to(device) + out_embedding = self.model_feat(input_tensor) + out_embedding = out_embedding.cpu().numpy()[ + 0, :] # feature array with 512 elements + + output = {OutputKeys.IMG_EMBEDDING: None} + output[OutputKeys.IMG_EMBEDDING] = out_embedding + return output diff --git a/modelscope/outputs.py b/modelscope/outputs.py index dee31a4f..10333855 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -140,6 +140,12 @@ TASK_OUTPUTS = { # } Tasks.ocr_detection: [OutputKeys.POLYGONS], + # image embedding result for a single image + # { + # "image_bedding": np.array with shape [D] + # } + Tasks.product_retrieval_embedding: [OutputKeys.IMG_EMBEDDING], + # video embedding result for single video # { # "video_embedding": np.array with shape [D], diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index a0e5b5af..50652ac1 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -109,6 +109,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_gan_face-image-generation'), Tasks.image_super_resolution: (Pipelines.image_super_resolution, 'damo/cv_rrdb_image-super-resolution'), + Tasks.product_retrieval_embedding: + (Pipelines.product_retrieval_embedding, + 'damo/cv_resnet50_product-bag-embedding-models'), Tasks.image_classification_imagenet: (Pipelines.general_image_classification, 
'damo/cv_vit-base_image-classification_ImageNet-labels'), diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 35230f08..e66176e4 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from .face_detection_pipeline import FaceDetectionPipeline from .face_recognition_pipeline import FaceRecognitionPipeline from .face_image_generation_pipeline import FaceImageGenerationPipeline + from .image_classification_pipeline import ImageClassificationPipeline from .image_cartoon_pipeline import ImageCartoonPipeline from .image_classification_pipeline import GeneralImageClassificationPipeline from .image_denoise_pipeline import ImageDenoisePipeline @@ -20,12 +21,12 @@ if TYPE_CHECKING: from .image_matting_pipeline import ImageMattingPipeline from .image_super_resolution_pipeline import ImageSuperResolutionPipeline from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline + from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline from .style_transfer_pipeline import StyleTransferPipeline from .live_category_pipeline import LiveCategoryPipeline from .ocr_detection_pipeline import OCRDetectionPipeline from .video_category_pipeline import VideoCategoryPipeline from .virtual_tryon_pipeline import VirtualTryonPipeline - from .image_classification_pipeline import ImageClassificationPipeline else: _import_structure = { 'action_recognition_pipeline': ['ActionRecognitionPipeline'], @@ -47,6 +48,8 @@ else: 'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'], 'image_to_image_translation_pipeline': ['Image2ImageTranslationPipeline'], + 'product_retrieval_embedding_pipeline': + ['ProductRetrievalEmbeddingPipeline'], 'live_category_pipeline': ['LiveCategoryPipeline'], 'ocr_detection_pipeline': ['OCRDetectionPipeline'], 'style_transfer_pipeline': ['StyleTransferPipeline'], diff --git a/modelscope/pipelines/cv/product_retrieval_embedding_pipeline.py b/modelscope/pipelines/cv/product_retrieval_embedding_pipeline.py new file mode 100644 index 00000000..2614983b --- /dev/null +++ b/modelscope/pipelines/cv/product_retrieval_embedding_pipeline.py @@ -0,0 +1,45 @@ +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.metainfo import Pipelines +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.product_retrieval_embedding, + module_name=Pipelines.product_retrieval_embedding) +class ProductRetrievalEmbeddingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """use `model` to create a pipeline for prediction + Args: + model: model id on modelscope hub. 
+ """ + super().__init__(model=model, **kwargs) + + def preprocess(self, input: Input) -> Dict[str, Any]: + """ + preprocess the input image to cv2-bgr style + """ + img = LoadImage.convert_to_ndarray(input) # array with rgb + img = np.ascontiguousarray(img[:, :, ::-1]) # array with bgr + result = {'img': img} # only for detection + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + return self.model(input) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 6bac48ee..ec829eaf 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -37,6 +37,7 @@ class CVTasks(object): face_image_generation = 'face-image-generation' image_super_resolution = 'image-super-resolution' style_transfer = 'style-transfer' + product_retrieval_embedding = 'product-retrieval-embedding' live_category = 'live-category' video_category = 'video-category' image_classification_imagenet = 'image-classification-imagenet' diff --git a/requirements/cv.txt b/requirements/cv.txt index a0f505c0..bd1a72db 100644 --- a/requirements/cv.txt +++ b/requirements/cv.txt @@ -1,5 +1,8 @@ decord>=0.6.0 easydict +# tensorflow 1.x compatability requires numpy version to be cap at 1.18 +numpy<=1.18 +onnxruntime>=1.10 tf_slim timm torchvision diff --git a/requirements/runtime.txt b/requirements/runtime.txt index fbf33854..491c4f21 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -4,7 +4,8 @@ easydict einops filelock>=3.3.0 gast>=0.2.2 -numpy +# tensorflow 1.x compatability requires numpy version to be cap at 1.18 +numpy<=1.18 opencv-python oss2 Pillow>=6.2.0 diff --git a/tests/pipelines/test_product_retrieval_embedding.py b/tests/pipelines/test_product_retrieval_embedding.py new file mode 100644 index 00000000..c0129ec5 --- /dev/null +++ b/tests/pipelines/test_product_retrieval_embedding.py @@ -0,0 +1,39 @@ +import unittest + +import numpy as np + +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class ProductRetrievalEmbeddingTest(unittest.TestCase): + model_id = 'damo/cv_resnet50_product-bag-embedding-models' + img_input = 'data/test/images/product_embed_bag.jpg' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_model_name(self): + product_embed = pipeline(Tasks.product_retrieval_embedding, + self.model_id) + result = product_embed(self.img_input)[OutputKeys.IMG_EMBEDDING] + print('abs sum value is: {}'.format(np.sum(np.abs(result)))) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + product_embed = pipeline( + task=Tasks.product_retrieval_embedding, model=model) + result = product_embed(self.img_input)[OutputKeys.IMG_EMBEDDING] + print('abs sum value is: {}'.format(np.sum(np.abs(result)))) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + product_embed = pipeline(task=Tasks.product_retrieval_embedding) + result = product_embed(self.img_input)[OutputKeys.IMG_EMBEDDING] + print('abs sum value is: {}'.format(np.sum(np.abs(result)))) + + +if __name__ == '__main__': + unittest.main()
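
For reference, a minimal usage sketch of the new pipeline, assembled from the test cases above (the model id and sample image are the ones used in tests/pipelines/test_product_retrieval_embedding.py); it assumes the model weights are resolvable from the ModelScope hub:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the product retrieval embedding pipeline with the model id used in the tests.
product_embed = pipeline(
    Tasks.product_retrieval_embedding,
    model='damo/cv_resnet50_product-bag-embedding-models')

# The pipeline loads the image, converts it to BGR, runs the ONNX YOLOX detector
# with the bag category, crops the top-1 box, and returns a 512-d L2-normalized
# feature vector under OutputKeys.IMG_EMBEDDING.
result = product_embed('data/test/images/product_embed_bag.jpg')
embedding = result[OutputKeys.IMG_EMBEDDING]
print(embedding.shape)  # expected: (512,)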