Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9515599master

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:08691a9373aa6d05b236a4ba788f3eccdea4c37aa77b30fc94b02ec3e1f18210
size 367017
@@ -16,6 +16,7 @@ class Models(object):
    nafnet = 'nafnet'
    csrnet = 'csrnet'
    cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin'
    product_retrieval_embedding = 'product-retrieval-embedding'

    # nlp models
    bert = 'bert'
@@ -84,6 +85,7 @@ class Pipelines(object):
    image_super_resolution = 'rrdb-image-super-resolution'
    face_image_generation = 'gan-face-image-generation'
    style_transfer = 'AAMS-style-transfer'
    product_retrieval_embedding = 'resnet50-product-retrieval-embedding'
    face_recognition = 'ir101-face-recognition-cfglint'
    image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
    image2image_translation = 'image-to-image-translation'
@@ -3,5 +3,5 @@ from . import (action_recognition, animal_recognition, cartoon,
               cmdssl_video_embedding, face_detection, face_generation,
               image_classification, image_color_enhance, image_colorization,
               image_denoise, image_instance_segmentation,
-              image_to_image_translation, object_detection, super_resolution,
-              virual_tryon)
+              image_to_image_translation, object_detection,
+              product_retrieval_embedding, super_resolution, virual_tryon)
@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .item_model import ProductRetrievalEmbedding

else:
    _import_structure = {
        'item_model': ['ProductRetrievalEmbedding'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
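A note on the lazy-import pattern above: `LazyImportModule` swaps itself into `sys.modules` in place of this package, so `ProductRetrievalEmbedding` (and its heavy torch/onnxruntime imports) is only materialized on first attribute access, while the `TYPE_CHECKING` branch keeps static analyzers and IDEs happy. The same idea can be sketched with only the standard library (this mirrors what `LazyImportModule` does conceptually; the actual modelscope implementation differs in details):

```python
# Deferred module loading with the stdlib, per the importlib docs recipe.
import importlib.util
import sys


def lazy_import(name):
    """Return a module whose real import runs on first attribute access."""
    spec = importlib.util.find_spec(name)
    loader = importlib.util.LazyLoader(spec.loader)
    spec.loader = loader
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    loader.exec_module(module)  # execution is deferred, not run here
    return module


np = lazy_import('numpy')  # nothing heavy happens yet
print(np.zeros(3))         # numpy is actually imported on this line
```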
@@ -0,0 +1,517 @@
import cv2
import numpy as np


class YOLOXONNX(object):
    """Product detection model with ONNX inference."""

    def __init__(self, onnx_path, multi_detect=False):
        """Create the product detection model.

        Args:
            onnx_path: path to the ONNX model for product detection.
            multi_detect: detection parameter; should be set to False.
        """
        self.input_reso = 416
        self.iou_thr = 0.45
        self.score_thr = 0.3
        self.img_shape = (self.input_reso, self.input_reso, 3)
        self.num_classes = 13
        self.onnx_path = onnx_path
        # Import onnxruntime lazily so the package stays importable
        # when this model is not used.
        import onnxruntime as ort
        self.ort_session = ort.InferenceSession(self.onnx_path)
        self.with_p6 = False
        self.multi_detect = multi_detect

    def format_judge(self, img):
        """Resize over-large images so the long side is at most 1024 pixels
        while the short side stays at least 100 pixels."""
        m_min_width = 100
        m_min_height = 100
        height, width, c = img.shape
        if width * height > 1024 * 1024:
            if height > width:
                long_side = height
                short_side = width
                long_ratio = float(long_side) / 1024.0
                short_ratio = float(short_side) / float(m_min_width)
            else:
                long_side = width
                short_side = height
                long_ratio = float(long_side) / 1024.0
                short_ratio = float(short_side) / float(m_min_height)
            if long_side == height:
                if long_ratio < short_ratio:
                    height_new = 1024
                    width_new = int((1024 * width) / height)
                else:
                    height_new = int((m_min_width * height) / width)
                    width_new = m_min_width
                img_res = cv2.resize(
                    img, (width_new, height_new),
                    interpolation=cv2.INTER_LINEAR)
            elif long_side == width:
                if long_ratio < short_ratio:
                    height_new = int((1024 * height) / width)
                    width_new = 1024
                else:
                    width_new = int((m_min_height * width) / height)
                    height_new = m_min_height
                img_res = cv2.resize(
                    img, (width_new, height_new),
                    interpolation=cv2.INTER_LINEAR)
        else:
            img_res = img
        return img_res

    def preprocess(self, image, input_size, swap=(2, 0, 1)):
        """Letterbox-resize an image to the model input size.

        Args:
            image: cv2 image in BGR format.
            input_size: model input size as (height, width).
            swap: axis order for the HWC -> CHW transpose.
        """
        if len(image.shape) == 3:
            padded_img = np.ones((input_size[0], input_size[1], 3)) * 114.0
        else:
            padded_img = np.ones(input_size) * 114.0
        img = np.array(image)
        # Scale so the image fits inside input_size while keeping its aspect
        # ratio, then pad the remainder with the constant value 114.
        r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
        resized_img = cv2.resize(
            img, (int(img.shape[1] * r), int(img.shape[0] * r)),
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.float32)
        padded_img[:int(img.shape[0] * r), :int(img.shape[1] * r)] = resized_img
        padded_img = padded_img.transpose(swap)
        padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
        return padded_img, r
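Aside on the letterbox step: the scale factor `r` returned here is reused in `forward` below to map detections back to original-image coordinates. A self-contained check of the arithmetic (shapes only; constructing `YOLOXONNX` is skipped since it needs an ONNX file, and the toy image size is made up):

```python
# Reproduce the letterbox arithmetic from preprocess() on a toy image.
input_size = (416, 416)
img_h, img_w = 512, 1024                 # landscape test image
r = min(input_size[0] / img_h, input_size[1] / img_w)
print(r)                                  # 0.40625
print(int(img_h * r), int(img_w * r))     # 208 416: fits inside the canvas
# Boxes predicted on the 416x416 canvas are divided by r in forward()
# to land back in original-image coordinates.
```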

    def cal_iou(self, val1, val2):
        """Intersection-over-union of two boxes given as (x1, y1, x2, y2)."""
        x11, y11, x12, y12 = val1
        x21, y21, x22, y22 = val2
        leftX = max(x11, x21)
        topY = max(y11, y21)
        rightX = min(x12, x22)
        bottomY = min(y12, y22)
        if rightX < leftX or bottomY < topY:
            return 0
        area = float((rightX - leftX) * (bottomY - topY))
        # Union = sum of the two box areas minus the intersection.
        barea = (x12 - x11) * (y12 - y11) + (x22 - x21) * (y22 - y21) - area
        if barea <= 0:
            return 0
        return area / barea

    def nms(self, boxes, scores, nms_thr):
        """Single-class NMS implemented in NumPy."""
        x1 = boxes[:, 0]
        y1 = boxes[:, 1]
        x2 = boxes[:, 2]
        y2 = boxes[:, 3]
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]
        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])
            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[order[1:]] - inter)
            inds = np.where(ovr <= nms_thr)[0]
            order = order[inds + 1]
        return keep
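Greedy NMS keeps the highest-scoring box, drops every remaining box whose IoU with it exceeds the threshold, and repeats. A standalone toy run of the same logic (a free-function copy of the method above on synthetic boxes, so no ONNX session is required):

```python
import numpy as np


def nms(boxes, scores, nms_thr):
    # Same greedy single-class NMS as YOLOXONNX.nms above.
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)
        order = order[np.where(ovr <= nms_thr)[0] + 1]
    return keep


boxes = np.array([[0, 0, 100, 100],      # IoU ~0.68 with the next box
                  [10, 10, 110, 110],
                  [200, 200, 300, 300]], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)
print(nms(boxes, scores, nms_thr=0.45))  # indices 0 and 2; box 1 suppressed
```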

    def multiclass_nms(self, boxes, scores, nms_thr, score_thr):
        """Multi-class NMS implemented in NumPy."""
        final_dets = []
        num_classes = scores.shape[1]
        for cls_ind in range(num_classes):
            cls_scores = scores[:, cls_ind]
            valid_score_mask = cls_scores > score_thr
            if valid_score_mask.sum() == 0:
                continue
            valid_scores = cls_scores[valid_score_mask]
            valid_boxes = boxes[valid_score_mask]
            keep = self.nms(valid_boxes, valid_scores, nms_thr)
            if len(keep) > 0:
                cls_inds = np.ones((len(keep), 1)) * cls_ind
                dets = np.concatenate(
                    [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1)
                final_dets.append(dets)
        if len(final_dets) == 0:
            return None
        return np.concatenate(final_dets, 0)

    def postprocess(self, outputs, img_size, p6=False):
        """Decode raw YOLOX outputs from grid offsets to absolute boxes."""
        grids = []
        expanded_strides = []
        if not p6:
            strides = [8, 16, 32]
        else:
            strides = [8, 16, 32, 64]
        hsizes = [img_size[0] // stride for stride in strides]
        wsizes = [img_size[1] // stride for stride in strides]
        for hsize, wsize, stride in zip(hsizes, wsizes, strides):
            xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
            grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
            grids.append(grid)
            shape = grid.shape[:2]
            expanded_strides.append(np.full((*shape, 1), stride))
        grids = np.concatenate(grids, 1)
        expanded_strides = np.concatenate(expanded_strides, 1)
        # Box centers are (cell index + offset) * stride; widths and heights
        # are exp(prediction) * stride.
        outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
        outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
        return outputs
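For the 416x416 input used here, the three stride levels contribute 52x52, 26x26 and 13x13 grid cells, so the ONNX head emits 3549 prediction rows per image. A quick check of that bookkeeping:

```python
strides = [8, 16, 32]
cells = [(416 // s) ** 2 for s in strides]
print(cells)       # [2704, 676, 169]
print(sum(cells))  # 3549 rows, each [cx, cy, w, h, objectness, 13 class scores]
```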

    def get_new_box_order(self, bboxes, labels, img_h, img_w):
        """Re-rank boxes by score and greedily drop near-duplicates
        (IoU >= 0.45 with an already-kept box)."""
        bboxes = np.hstack((bboxes, np.zeros((bboxes.shape[0], 1))))
        scores = bboxes[:, 4]
        order = scores.argsort()[::-1]
        bboxes_temp = bboxes[order]
        labels_temp = labels[order]
        bboxes = np.empty((0, 6))
        bboxes = np.vstack((bboxes, bboxes_temp[0].tolist()))
        labels = np.empty((0, ))
        labels = np.hstack((labels, [labels_temp[0]]))
        for i in range(1, bboxes_temp.shape[0]):
            iou_max = 0
            for j in range(bboxes.shape[0]):
                iou_temp = self.cal_iou(bboxes_temp[i][:4], bboxes[j][:4])
                if iou_temp > iou_max:
                    iou_max = iou_temp
            if iou_max < 0.45:
                bboxes = np.vstack((bboxes, bboxes_temp[i].tolist()))
                labels = np.hstack((labels, [labels_temp[i]]))
        # Keep at least one box, and at most as many boxes as score > 0.3.
        num_03 = (scores > 0.3).sum()
        num_out = max(num_03, 1)
        bboxes = bboxes[:num_out, :]
        labels = labels[:num_out]
        return bboxes, labels

    def forward(self, img_input, cid='0', sub_class=False):
        """Run product detection on an input image.

        Returns the coordinate string ('x1,x2,y1,y2' per box) and the top-1
        crop resized to 224x224 for the embedding network.
        """
        input_shape = self.img_shape
        img, ratio = self.preprocess(img_input, input_shape)
        img_h, img_w = img_input.shape[:2]
        ort_inputs = {
            self.ort_session.get_inputs()[0].name: img[None, :, :, :]
        }
        output = self.ort_session.run(None, ort_inputs)
        predictions = self.postprocess(output[0], input_shape, self.with_p6)[0]
        boxes = predictions[:, :4]
        scores = predictions[:, 4:5] * predictions[:, 5:]
        # Convert (cx, cy, w, h) to (x1, y1, x2, y2) and undo the letterbox
        # scale so boxes live in original-image coordinates.
        boxes_xyxy = np.ones_like(boxes)
        boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
        boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
        boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
        boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
        boxes_xyxy /= ratio
        dets = self.multiclass_nms(
            boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
        if dets is None:
            # No detection: fall back to the whole image.
            top1_bbox_str = str(0) + ',' + str(img_w) + ',' + str(
                0) + ',' + str(img_h)
            crop_img = img_input.copy()
            coord = top1_bbox_str
        else:
            bboxes = dets[:, :5]
            labels = dets[:, 5]
            if not self.multi_detect:
                cid = int(cid)
                if not sub_class:
                    if cid > -1:
                        if cid == 0:  # cloth
                            cid_ind1 = np.where(labels < 3)
                            cid_ind2 = np.where(labels == 9)
                            cid_ind = np.hstack((cid_ind1[0], cid_ind2[0]))
                            scores = bboxes[cid_ind, -1]  # labels 0, 1, 2, 9
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 3:  # bag
                            cid_ind = np.where(labels == 3)
                            scores = bboxes[cid_ind, -1]  # label 3
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 4:  # shoe
                            cid_ind = np.where(labels == 4)
                            scores = bboxes[cid_ind, -1]  # label 4
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        else:  # other
                            cid_ind5 = np.where(labels == 5)
                            cid_ind6 = np.where(labels == 6)
                            cid_ind7 = np.where(labels == 7)
                            cid_ind8 = np.where(labels == 8)
                            cid_ind10 = np.where(labels == 10)
                            cid_ind11 = np.where(labels == 11)
                            cid_ind12 = np.where(labels == 12)
                            cid_ind = np.hstack(
                                (cid_ind5[0], cid_ind6[0], cid_ind7[0],
                                 cid_ind8[0], cid_ind10[0], cid_ind11[0],
                                 cid_ind12[0]))
                            scores = bboxes[cid_ind, -1]  # labels 5-8, 10-12
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                    else:
                        bboxes, labels = self.get_new_box_order(
                            bboxes, labels, img_h, img_w)
                else:
                    if cid > -1:
                        if cid == 0:  # upper
                            cid_ind = np.where(labels == 0)
                            scores = bboxes[cid_ind, -1]  # label 0
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 1:  # skirt
                            cid_ind = np.where(labels == 1)
                            scores = bboxes[cid_ind, -1]  # label 1
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 2:  # lower
                            cid_ind = np.where(labels == 2)
                            scores = bboxes[cid_ind, -1]  # label 2
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 3:  # bag
                            cid_ind = np.where(labels == 3)
                            scores = bboxes[cid_ind, -1]  # label 3
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 4:  # shoe
                            cid_ind = np.where(labels == 4)
                            scores = bboxes[cid_ind, -1]  # label 4
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 5:  # access
                            cid_ind = np.where(labels == 5)
                            scores = bboxes[cid_ind, -1]  # label 5
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 7:  # beauty
                            cid_ind = np.where(labels == 6)
                            scores = bboxes[cid_ind, -1]  # label 6
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 9:  # furniture
                            cid_ind = np.where(labels == 8)
                            scores = bboxes[cid_ind, -1]  # label 8
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 21:  # underwear
                            cid_ind = np.where(labels == 9)
                            scores = bboxes[cid_ind, -1]  # label 9
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 22:  # digital
                            cid_ind = np.where(labels == 11)
                            scores = bboxes[cid_ind, -1]  # label 11
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        else:  # other
                            cid_ind5 = np.where(labels == 7)  # bottle
                            cid_ind6 = np.where(labels == 10)  # toy
                            cid_ind7 = np.where(labels == 12)  # toy
                            cid_ind = np.hstack(
                                (cid_ind5[0], cid_ind6[0], cid_ind7[0]))
                            scores = bboxes[cid_ind, -1]  # labels 7, 10, 12
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                    else:
                        bboxes, labels = self.get_new_box_order(
                            bboxes, labels, img_h, img_w)
            else:
                bboxes, labels = self.get_new_box_order(
                    bboxes, labels, img_h, img_w)
            # Clamp the top-1 box to the image bounds.
            top1_bbox = bboxes[0].astype(np.int32)
            top1_bbox[0] = min(max(0, top1_bbox[0]), img_input.shape[1] - 1)
            top1_bbox[1] = min(max(0, top1_bbox[1]), img_input.shape[0] - 1)
            top1_bbox[2] = max(min(img_input.shape[1] - 1, top1_bbox[2]), 0)
            top1_bbox[3] = max(min(img_input.shape[0] - 1, top1_bbox[3]), 0)
            if not self.multi_detect:
                top1_bbox_str = str(top1_bbox[0]) + ',' + str(
                    top1_bbox[2]) + ',' + str(top1_bbox[1]) + ',' + str(
                        top1_bbox[3])  # x1, x2, y1, y2
                crop_img = img_input[top1_bbox[1]:top1_bbox[3],
                                     top1_bbox[0]:top1_bbox[2], :]
                coord = top1_bbox_str
            else:
                coord = ''
                for i in range(0, len(bboxes)):
                    top_bbox = bboxes[i].astype(np.int32)
                    top_bbox[0] = min(
                        max(0, top_bbox[0]), img_input.shape[1] - 1)
                    top_bbox[1] = min(
                        max(0, top_bbox[1]), img_input.shape[0] - 1)
                    top_bbox[2] = max(
                        min(img_input.shape[1] - 1, top_bbox[2]), 0)
                    top_bbox[3] = max(
                        min(img_input.shape[0] - 1, top_bbox[3]), 0)
                    coord = coord + str(top_bbox[0]) + ',' + str(
                        top_bbox[2]) + ',' + str(top_bbox[1]) + ',' + str(
                            top_bbox[3]) + ',' + str(bboxes[i][4]) + ',' + str(
                                bboxes[i][5]) + ';'  # x1, x2, y1, y2, conf
                crop_img = img_input[top1_bbox[1]:top1_bbox[3],
                                     top1_bbox[0]:top1_bbox[2], :]
        crop_img = cv2.resize(crop_img, (224, 224))
        return coord, crop_img  # top-1 crop and its coordinate string
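For reference, a minimal driver for this class (the file name `onnx_detection.onnx` comes from the model directory used in `item_model.py` below; the local image path here is hypothetical):

```python
# Requires `pip install onnxruntime opencv-python`; paths are hypothetical.
import cv2

detector = YOLOXONNX(onnx_path='onnx_detection.onnx', multi_detect=False)
img = cv2.imread('bag.jpg')                  # BGR, as the model expects
coord, crop = detector.forward(img, cid='3')  # '3' selects the bag category
print(coord)                                  # 'x1,x2,y1,y2' of the top-1 box
print(crop.shape)                             # (224, 224, 3) crop for embedding
```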
@@ -0,0 +1,157 @@
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def gn_init(m, zero_init=False):
    """Initialize a GroupNorm layer; zero_init zeroes the scale so the
    residual branch starts as an identity mapping."""
    assert isinstance(m, nn.GroupNorm)
    m.weight.data.fill_(0. if zero_init else 1.)
    m.bias.data.zero_()


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """Bottleneck block for resnet-style networks.

        Args:
            inplanes: input channel number.
            planes: bottleneck channel number (output is planes * 4).
        """
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.GroupNorm(32, planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False)
        self.bn2 = nn.GroupNorm(32, planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.GroupNorm(32, planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        gn_init(self.bn1)
        gn_init(self.bn2)
        gn_init(self.bn3, zero_init=True)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    """ResNet-style network with group normalization."""

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.GroupNorm(32, 64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        # layer4 keeps stride 1, so a 224x224 input yields a 14x14 feature map.
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1)
        self.gap = nn.AvgPool2d((14, 14))
        self.reduce_conv = nn.Conv2d(2048, 512, kernel_size=1)
        gn_init(self.bn1)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Average-pool before the 1x1 projection so the shortcut does
            # not discard activations when downsampling.
            downsample = nn.Sequential(
                nn.AvgPool2d(stride, stride),
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=1,
                    bias=False),
                nn.GroupNorm(32, planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.gap(x)
        x = self.reduce_conv(x)  # reduce 2048 channels to 512
        x = x.view(x.size(0), -1)  # flatten to (batch, 512)
        return F.normalize(x, p=2, dim=1)

def preprocess(img):
    """Convert a cv2 BGR image into a normalized NCHW float array."""
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img_size = 224
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_new = cv2.resize(
        img, (img_size, img_size), interpolation=cv2.INTER_LINEAR)
    content = np.array(img_new).astype(np.float32)
    content = (content / 255.0 - mean) / std
    # HWC -> CHW, then add the batch dimension.
    img_new = content.transpose(2, 0, 1)
    img_new = img_new[np.newaxis, :, :, :]
    return img_new

def resnet50_embed():
    """Create a ResNet-50 embedding network with group normalization."""
    net = ResNet(Bottleneck, [3, 4, 6, 3])
    return net
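A quick smoke test of the embedder on random data (no pretrained weights, so the values are meaningless; it only checks shapes and the L2 normalization):

```python
import numpy as np
import torch

net = resnet50_embed().eval()
dummy = (np.random.rand(300, 400, 3) * 255).astype(np.uint8)  # fake BGR image
inp = torch.from_numpy(preprocess(dummy).astype(np.float32))
with torch.no_grad():
    feat = net(inp)
print(feat.shape)          # torch.Size([1, 512])
print(float(feat.norm()))  # ~1.0: embeddings are L2-normalized
```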
@@ -0,0 +1,115 @@
import os.path as osp
from typing import Any, Dict

import numpy as np
import torch

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.product_retrieval_embedding.item_detection import \
    YOLOXONNX
from modelscope.models.cv.product_retrieval_embedding.item_embedding import (
    preprocess, resnet50_embed)
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import create_device

logger = get_logger()

__all__ = ['ProductRetrievalEmbedding']


@MODELS.register_module(
    Tasks.product_retrieval_embedding,
    module_name=Models.product_retrieval_embedding)
class ProductRetrievalEmbedding(TorchModel):

    def __init__(self, model_dir, device='cpu', **kwargs):
        super().__init__(model_dir=model_dir, device=device, **kwargs)

        def filter_param(src_params, own_state):
            """Copy parameters whose normalized names and shapes match."""
            copied_keys = []
            for name, param in src_params.items():
                if name.startswith('module.'):  # strip DataParallel prefix
                    name = name[7:]
                if '.module.' not in list(own_state.keys())[0]:
                    name = name.replace('.module.', '.')
                if (name in own_state) and (own_state[name].shape
                                            == param.shape):
                    own_state[name].copy_(param)
                    copied_keys.append(name)

        def load_pretrained(model, src_params):
            if 'state_dict' in src_params:
                src_params = src_params['state_dict']
            own_state = model.state_dict()
            filter_param(src_params, own_state)
            model.load_state_dict(own_state)

        cpu_flag = device == 'cpu'
        self.device = create_device(
            cpu_flag)  # device.type is either 'cpu' or 'cuda'
        self.use_gpu = self.device.type == 'cuda'
        # config the model path
        self.local_model_dir = model_dir

        # init the embedding model; its input is a cv2 BGR image
        self.preprocess_for_embed = preprocess
        model_feat = resnet50_embed()
        src_params = torch.load(
            osp.join(self.local_model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
            'cpu')
        load_pretrained(model_feat, src_params)
        if self.use_gpu:
            model_feat.to(self.device)
            logger.info('Use GPU: {}'.format(self.device))
        else:
            logger.info('Use CPU for inference')
        self.model_feat = model_feat

        # init the detection model
        self.model_det = YOLOXONNX(
            onnx_path=osp.join(self.local_model_dir, 'onnx_detection.onnx'),
            multi_detect=False)
        logger.info('load model done')

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Detect the product in the input image and extract its embedding.

        The input image must be in cv2 BGR format.
        """
        assert 'img' in input.keys()

        def set_phase(model, is_train):
            if is_train:
                model.train()
            else:
                model.eval()

        is_train = False
        set_phase(self.model_feat, is_train)
        img = input['img']  # image for detection
        cid = '3'  # detection category: bag
        # transform img (tensor) to a numpy array in BGR order
        if isinstance(img, torch.Tensor):
            img = img.data.cpu().numpy()
        res, crop_img = self.model_det.forward(img,
                                               cid)  # detect the bag category
        crop_img = self.preprocess_for_embed(crop_img)  # embedding preprocess
        input_tensor = torch.from_numpy(crop_img.astype(np.float32))
        device = next(self.model_feat.parameters()).device
        use_gpu = device.type == 'cuda'
        with torch.no_grad():
            if use_gpu:
                input_tensor = input_tensor.to(device)
            out_embedding = self.model_feat(input_tensor)
            out_embedding = out_embedding.cpu().numpy()[
                0, :]  # feature array with 512 elements
        output = {OutputKeys.IMG_EMBEDDING: out_embedding}
        return output
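The tests at the end of this review exercise the pipeline; calling the model class directly looks like this (a sketch assuming a locally downloaded model directory containing `pytorch_model.bin` and `onnx_detection.onnx`; both paths are hypothetical):

```python
import cv2
from modelscope.outputs import OutputKeys

model = ProductRetrievalEmbedding(model_dir='/path/to/model_dir', device='cpu')
img = cv2.imread('bag.jpg')  # BGR ndarray, as forward() expects
embedding = model.forward({'img': img})[OutputKeys.IMG_EMBEDDING]
print(embedding.shape)       # (512,)
```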
@@ -140,6 +140,12 @@ TASK_OUTPUTS = {
    #   }
    Tasks.ocr_detection: [OutputKeys.POLYGONS],

    # image embedding result for a single image
    #   {
    #       "img_embedding": np.array with shape [D]
    #   }
    Tasks.product_retrieval_embedding: [OutputKeys.IMG_EMBEDDING],

    # video embedding result for a single video
    #   {
    #       "video_embedding": np.array with shape [D],
@@ -109,6 +109,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                                  'damo/cv_gan_face-image-generation'),
    Tasks.image_super_resolution: (Pipelines.image_super_resolution,
                                   'damo/cv_rrdb_image-super-resolution'),
    Tasks.product_retrieval_embedding:
    (Pipelines.product_retrieval_embedding,
     'damo/cv_resnet50_product-bag-embedding-models'),
    Tasks.image_classification_imagenet:
    (Pipelines.general_image_classification,
     'damo/cv_vit-base_image-classification_ImageNet-labels'),
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
    from .face_detection_pipeline import FaceDetectionPipeline
    from .face_recognition_pipeline import FaceRecognitionPipeline
    from .face_image_generation_pipeline import FaceImageGenerationPipeline
+   from .image_classification_pipeline import ImageClassificationPipeline
    from .image_cartoon_pipeline import ImageCartoonPipeline
    from .image_classification_pipeline import GeneralImageClassificationPipeline
    from .image_denoise_pipeline import ImageDenoisePipeline
@@ -20,12 +21,12 @@ if TYPE_CHECKING:
    from .image_matting_pipeline import ImageMattingPipeline
    from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
    from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline
+   from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline
    from .style_transfer_pipeline import StyleTransferPipeline
    from .live_category_pipeline import LiveCategoryPipeline
    from .ocr_detection_pipeline import OCRDetectionPipeline
    from .video_category_pipeline import VideoCategoryPipeline
    from .virtual_tryon_pipeline import VirtualTryonPipeline
-   from .image_classification_pipeline import ImageClassificationPipeline
else:
    _import_structure = {
        'action_recognition_pipeline': ['ActionRecognitionPipeline'],
@@ -47,6 +48,8 @@ else:
        'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'],
        'image_to_image_translation_pipeline':
        ['Image2ImageTranslationPipeline'],
+       'product_retrieval_embedding_pipeline':
+       ['ProductRetrievalEmbeddingPipeline'],
        'live_category_pipeline': ['LiveCategoryPipeline'],
        'ocr_detection_pipeline': ['OCRDetectionPipeline'],
        'style_transfer_pipeline': ['StyleTransferPipeline'],
@@ -0,0 +1,45 @@
from typing import Any, Dict

import numpy as np

from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.product_retrieval_embedding,
    module_name=Pipelines.product_retrieval_embedding)
class ProductRetrievalEmbeddingPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Use `model` to create a pipeline for prediction.

        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Convert the input image to cv2 BGR format for detection."""
        img = LoadImage.convert_to_ndarray(input)  # RGB array
        img = np.ascontiguousarray(img[:, :, ::-1])  # to BGR
        result = {'img': img}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        return self.model(input)

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
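Since the embeddings are L2-normalized (see `item_embedding.py` above), retrieval against a gallery reduces to a dot product. A sketch of the ranking step a caller would run on top of this pipeline (the gallery here is synthetic; real embeddings would come from pipeline calls like the tests below):

```python
import numpy as np

# query_emb and gallery_embs would come from the pipeline, e.g.
#   query_emb = product_embed('query_bag.jpg')[OutputKeys.IMG_EMBEDDING]
# Here we fake a gallery of 5 unit vectors to show only the ranking step.
rng = np.random.default_rng(0)
gallery_embs = rng.normal(size=(5, 512)).astype(np.float32)
gallery_embs /= np.linalg.norm(gallery_embs, axis=1, keepdims=True)
query_emb = gallery_embs[2] + 0.1 * rng.normal(size=512)
query_emb /= np.linalg.norm(query_emb)

# Cosine similarity == dot product for unit vectors; higher is more similar.
sims = gallery_embs @ query_emb
print(np.argsort(-sims))  # item 2 should rank first
```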
@@ -37,6 +37,7 @@ class CVTasks(object):
    face_image_generation = 'face-image-generation'
    image_super_resolution = 'image-super-resolution'
    style_transfer = 'style-transfer'
    product_retrieval_embedding = 'product-retrieval-embedding'
    live_category = 'live-category'
    video_category = 'video-category'
    image_classification_imagenet = 'image-classification-imagenet'
@@ -1,5 +1,8 @@
decord>=0.6.0
easydict
# tensorflow 1.x compatibility requires numpy to be capped at 1.18
numpy<=1.18
onnxruntime>=1.10
tf_slim
timm
torchvision
@@ -4,7 +4,8 @@ easydict
einops
filelock>=3.3.0
gast>=0.2.2
-numpy
+# tensorflow 1.x compatibility requires numpy to be capped at 1.18
+numpy<=1.18
opencv-python
oss2
Pillow>=6.2.0
@@ -0,0 +1,39 @@
import unittest

import numpy as np

from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class ProductRetrievalEmbeddingTest(unittest.TestCase):
    model_id = 'damo/cv_resnet50_product-bag-embedding-models'
    img_input = 'data/test/images/product_embed_bag.jpg'

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_model_name(self):
        product_embed = pipeline(Tasks.product_retrieval_embedding,
                                 self.model_id)
        result = product_embed(self.img_input)[OutputKeys.IMG_EMBEDDING]
        print('abs sum value is: {}'.format(np.sum(np.abs(result))))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        product_embed = pipeline(
            task=Tasks.product_retrieval_embedding, model=model)
        result = product_embed(self.img_input)[OutputKeys.IMG_EMBEDDING]
        print('abs sum value is: {}'.format(np.sum(np.abs(result))))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        product_embed = pipeline(task=Tasks.product_retrieval_embedding)
        result = product_embed(self.img_input)[OutputKeys.IMG_EMBEDDING]
        print('abs sum value is: {}'.format(np.sum(np.abs(result))))


if __name__ == '__main__':
    unittest.main()