Browse Source

change some settings in SSD

tags/v0.3.0-alpha^2
zhaoting 5 years ago
parent
commit
ac12df82d2
9 changed files with 493 additions and 468 deletions
  1. +1
    -0
      RELEASE.md
  2. +5
    -4
      example/ssd_coco2017/README.md
  3. +10
    -6
      example/ssd_coco2017/config.py
  4. +176
    -119
      example/ssd_coco2017/dataset.py
  5. +10
    -9
      example/ssd_coco2017/eval.py
  6. +39
    -14
      example/ssd_coco2017/run_distribute_train.sh
  7. +18
    -60
      example/ssd_coco2017/train.py
  8. +120
    -134
      example/ssd_coco2017/util.py
  9. +114
    -122
      mindspore/model_zoo/ssd.py

+ 1
- 0
RELEASE.md View File

@@ -7,6 +7,7 @@
* DeepFM: a factorization-machine based neural network for CTR prediction on Criteo dataset. * DeepFM: a factorization-machine based neural network for CTR prediction on Criteo dataset.
* DeepLabV3: significantly improves over our previous DeepLab versions without DenseCRF post-processing and attains comparable performance with other state-of-art models on the PASCAL VOC 2007 semantic image segmentation benchmark. * DeepLabV3: significantly improves over our previous DeepLab versions without DenseCRF post-processing and attains comparable performance with other state-of-art models on the PASCAL VOC 2007 semantic image segmentation benchmark.
* Faster-RCNN: towards real-time object detection with region proposal networks on COCO 2017 dataset. * Faster-RCNN: towards real-time object detection with region proposal networks on COCO 2017 dataset.
* SSD: a single stage object detection methods on COCO 2017 dataset.
* GoogLeNet: a deep convolutional neural network architecture codenamed Inception V1 for classification and detection on CIFAR-10 dataset. * GoogLeNet: a deep convolutional neural network architecture codenamed Inception V1 for classification and detection on CIFAR-10 dataset.
* Wide&Deep: jointly trained wide linear models and deep neural networks for recommender systems on Criteo dataset. * Wide&Deep: jointly trained wide linear models and deep neural networks for recommender systems on Criteo dataset.
* Frontend and User Interface * Frontend and User Interface


+ 5
- 4
example/ssd_coco2017/README.md View File

@@ -60,10 +60,10 @@ To train the model, run `train.py`. If the `MINDRECORD_DIR` is empty, it will ge
- Distribute mode - Distribute mode


``` ```
sh run_distribute_train.sh 8 150 coco /data/hccl.json
sh run_distribute_train.sh 8 500 0.2 coco /data/hccl.json
``` ```


The input parameters are device numbers, epoch size, dataset mode and [hccl json configuration file](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). **It is better to use absolute path.**
The input parameters are device numbers, epoch size, learning rate, dataset mode and [hccl json configuration file](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). **It is better to use absolute path.**


You will get the loss value of each step as following: You will get the loss value of each step as following:


@@ -75,14 +75,15 @@ epoch: 3 step: 455, loss is 5.458992
epoch: 148 step: 455, loss is 1.8340507 epoch: 148 step: 455, loss is 1.8340507
epoch: 149 step: 455, loss is 2.0876894 epoch: 149 step: 455, loss is 2.0876894
epoch: 150 step: 455, loss is 2.239692 epoch: 150 step: 455, loss is 2.239692
...
``` ```


### Evaluation ### Evaluation


for evaluation , run `eval.py` with `ckpt_path`. `ckpt_path` is the path of [checkpoint](https://www.mindspore.cn/tutorial/en/master/use/saving_and_loading_model_parameters.html) file.
for evaluation , run `eval.py` with `checkpoint_path`. `checkpoint_path` is the path of [checkpoint](https://www.mindspore.cn/tutorial/en/master/use/saving_and_loading_model_parameters.html) file.


``` ```
python eval.py --ckpt_path ssd.ckpt --dataset coco
python eval.py --checkpoint_path ssd.ckpt --dataset coco
``` ```


You can run ```python eval.py -h``` to get more information. You can run ```python eval.py -h``` to get more information.

+ 10
- 6
example/ssd_coco2017/config.py View File

@@ -27,6 +27,9 @@ class ConfigSSD:
NUM_SSD_BOXES = 1917 NUM_SSD_BOXES = 1917
NEG_PRE_POSITIVE = 3 NEG_PRE_POSITIVE = 3
MATCH_THRESHOLD = 0.5 MATCH_THRESHOLD = 0.5
NMS_THRESHOLD = 0.6
MIN_SCORE = 0.05
TOP_K = 100


NUM_DEFAULT = [3, 6, 6, 6, 6, 6] NUM_DEFAULT = [3, 6, 6, 6, 6, 6]
EXTRAS_IN_CHANNELS = [256, 576, 1280, 512, 256, 256] EXTRAS_IN_CHANNELS = [256, 576, 1280, 512, 256, 256]
@@ -34,20 +37,21 @@ class ConfigSSD:
EXTRAS_STRIDES = [1, 1, 2, 2, 2, 2] EXTRAS_STRIDES = [1, 1, 2, 2, 2, 2]
EXTRAS_RATIO = [0.2, 0.2, 0.2, 0.25, 0.5, 0.25] EXTRAS_RATIO = [0.2, 0.2, 0.2, 0.25, 0.5, 0.25]
FEATURE_SIZE = [19, 10, 5, 3, 2, 1] FEATURE_SIZE = [19, 10, 5, 3, 2, 1]
SCALES = [21, 45, 99, 153, 207, 261, 315]
ASPECT_RATIOS = [(1,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)]
MIN_SCALE = 0.2
MAX_SCALE = 0.95
ASPECT_RATIOS = [(2,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)]
STEPS = (16, 32, 64, 100, 150, 300) STEPS = (16, 32, 64, 100, 150, 300)
PRIOR_SCALING = (0.1, 0.2) PRIOR_SCALING = (0.1, 0.2)




# `MINDRECORD_DIR` and `COCO_ROOT` are better to use absolute path. # `MINDRECORD_DIR` and `COCO_ROOT` are better to use absolute path.
MINDRECORD_DIR = "MindRecord_COCO"
COCO_ROOT = "coco2017"
MINDRECORD_DIR = "/data/MindRecord_COCO"
COCO_ROOT = "/data/coco2017"
TRAIN_DATA_TYPE = "train2017" TRAIN_DATA_TYPE = "train2017"
VAL_DATA_TYPE = "val2017" VAL_DATA_TYPE = "val2017"
INSTANCES_SET = "annotations/instances_{}.json" INSTANCES_SET = "annotations/instances_{}.json"
COCO_CLASSES = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', COCO_CLASSES = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire', 'hydrant',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
@@ -58,7 +62,7 @@ class ConfigSSD:
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave oven', 'toaster', 'sink',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush') 'teddy bear', 'hair drier', 'toothbrush')
NUM_CLASSES = len(COCO_CLASSES) NUM_CLASSES = len(COCO_CLASSES)

+ 176
- 119
example/ssd_coco2017/dataset.py View File

@@ -32,36 +32,38 @@ config = ConfigSSD()
class GeneratDefaultBoxes(): class GeneratDefaultBoxes():
""" """
Generate Default boxes for SSD, follows the order of (W, H, archor_sizes). Generate Default boxes for SSD, follows the order of (W, H, archor_sizes).
`self.default_boxes` has a shape of [archor_sizes, H, W, 4], the last dimension is [x, y, w, h].
`self.default_boxes_ltrb` has a shape as `self.default_boxes`, the last dimension is [x1, y1, x2, y2].
`self.default_boxes` has a shape of [archor_sizes, H, W, 4], the last dimension is [y, x, h, w].
`self.default_boxes_ltrb` has a shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2].
""" """
def __init__(self): def __init__(self):
fk = config.IMG_SHAPE[0] / np.array(config.STEPS) fk = config.IMG_SHAPE[0] / np.array(config.STEPS)
scale_rate = (config.MAX_SCALE - config.MIN_SCALE) / (len(config.NUM_DEFAULT) - 1)
scales = [config.MIN_SCALE + scale_rate * i for i in range(len(config.NUM_DEFAULT))] + [1.0]
self.default_boxes = [] self.default_boxes = []
for idex, feature_size in enumerate(config.FEATURE_SIZE): for idex, feature_size in enumerate(config.FEATURE_SIZE):
sk1 = config.SCALES[idex] / config.IMG_SHAPE[0]
sk2 = config.SCALES[idex + 1] / config.IMG_SHAPE[0]
sk1 = scales[idex]
sk2 = scales[idex + 1]
sk3 = math.sqrt(sk1 * sk2) sk3 = math.sqrt(sk1 * sk2)
if config.NUM_DEFAULT[idex] == 3:
all_sizes = [(0.5, 1.0), (1.0, 1.0), (1.0, 0.5)]
if idex == 0:
w, h = sk1 * math.sqrt(2), sk1 / math.sqrt(2)
all_sizes = [(0.1, 0.1), (w, h), (h, w)]
else: else:
all_sizes = [(sk1, sk1), (sk3, sk3)]
all_sizes = [(sk1, sk1)]
for aspect_ratio in config.ASPECT_RATIOS[idex]: for aspect_ratio in config.ASPECT_RATIOS[idex]:
w, h = sk1 * math.sqrt(aspect_ratio), sk1 / math.sqrt(aspect_ratio) w, h = sk1 * math.sqrt(aspect_ratio), sk1 / math.sqrt(aspect_ratio)
all_sizes.append((w, h)) all_sizes.append((w, h))
all_sizes.append((h, w)) all_sizes.append((h, w))
all_sizes.append((sk3, sk3))


assert len(all_sizes) == config.NUM_DEFAULT[idex] assert len(all_sizes) == config.NUM_DEFAULT[idex]


for i, j in it.product(range(feature_size), repeat=2): for i, j in it.product(range(feature_size), repeat=2):
for w, h in all_sizes: for w, h in all_sizes:
cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex] cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex]
box = [np.clip(k, 0, 1) for k in (cx, cy, w, h)]
self.default_boxes.append(box)
self.default_boxes.append([cy, cx, h, w])


def to_ltrb(cx, cy, w, h):
return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
def to_ltrb(cy, cx, h, w):
return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2


# For IoU calculation # For IoU calculation
self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32') self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32')
@@ -70,17 +72,22 @@ class GeneratDefaultBoxes():


default_boxes_ltrb = GeneratDefaultBoxes().default_boxes_ltrb default_boxes_ltrb = GeneratDefaultBoxes().default_boxes_ltrb
default_boxes = GeneratDefaultBoxes().default_boxes default_boxes = GeneratDefaultBoxes().default_boxes
x1, y1, x2, y2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1)
y1, x1, y2, x2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1)
vol_anchors = (x2 - x1) * (y2 - y1) vol_anchors = (x2 - x1) * (y2 - y1)
matching_threshold = config.MATCH_THRESHOLD matching_threshold = config.MATCH_THRESHOLD




def _rand(a=0., b=1.):
"""Generate random."""
return np.random.rand() * (b - a) + a


def ssd_bboxes_encode(boxes): def ssd_bboxes_encode(boxes):
""" """
Labels anchors with ground truth inputs. Labels anchors with ground truth inputs.


Args: Args:
boxex: ground truth with shape [N, 5], for each row, it stores [x, y, w, h, cls].
boxex: ground truth with shape [N, 5], for each row, it stores [y, x, h, w, cls].


Returns: Returns:
gt_loc: location ground truth with shape [num_anchors, 4]. gt_loc: location ground truth with shape [num_anchors, 4].
@@ -91,10 +98,10 @@ def ssd_bboxes_encode(boxes):
def jaccard_with_anchors(bbox): def jaccard_with_anchors(bbox):
"""Compute jaccard score a box and the anchors.""" """Compute jaccard score a box and the anchors."""
# Intersection bbox and volume. # Intersection bbox and volume.
xmin = np.maximum(x1, bbox[0])
ymin = np.maximum(y1, bbox[1])
xmax = np.minimum(x2, bbox[2])
ymax = np.minimum(y2, bbox[3])
ymin = np.maximum(y1, bbox[0])
xmin = np.maximum(x1, bbox[1])
ymax = np.minimum(y2, bbox[2])
xmax = np.minimum(x2, bbox[3])
w = np.maximum(xmax - xmin, 0.) w = np.maximum(xmax - xmin, 0.)
h = np.maximum(ymax - ymin, 0.) h = np.maximum(ymax - ymin, 0.)


@@ -110,12 +117,11 @@ def ssd_bboxes_encode(boxes):
for bbox in boxes: for bbox in boxes:
label = int(bbox[4]) label = int(bbox[4])
scores = jaccard_with_anchors(bbox) scores = jaccard_with_anchors(bbox)
idx = np.argmax(scores)
scores[idx] = 2.0
mask = (scores > matching_threshold) mask = (scores > matching_threshold)
if not np.any(mask):
mask[np.argmax(scores)] = True

mask = mask & (scores > pre_scores) mask = mask & (scores > pre_scores)
pre_scores = np.maximum(pre_scores, scores)
pre_scores = np.maximum(pre_scores, scores * mask)
t_label = mask * label + (1 - mask) * t_label t_label = mask * label + (1 - mask) * t_label
for i in range(4): for i in range(4):
t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i] t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i]
@@ -134,13 +140,13 @@ def ssd_bboxes_encode(boxes):
bboxes_t[:, 2:4] = np.log(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4]) / config.PRIOR_SCALING[1] bboxes_t[:, 2:4] = np.log(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4]) / config.PRIOR_SCALING[1]
bboxes[index] = bboxes_t bboxes[index] = bboxes_t


num_match_num = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32)
return bboxes, t_label.astype(np.int32), num_match_num
num_match = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32)
return bboxes, t_label.astype(np.int32), num_match


def ssd_bboxes_decode(boxes, index):
"""Decode predict boxes to [x, y, w, h]"""
boxes_t = boxes[index]
default_boxes_t = default_boxes[index]
def ssd_bboxes_decode(boxes):
"""Decode predict boxes to [y, x, h, w]"""
boxes_t = boxes.copy()
default_boxes_t = default_boxes.copy()
boxes_t[:, :2] = boxes_t[:, :2] * config.PRIOR_SCALING[0] * default_boxes_t[:, 2:] + default_boxes_t[:, :2] boxes_t[:, :2] = boxes_t[:, :2] * config.PRIOR_SCALING[0] * default_boxes_t[:, 2:] + default_boxes_t[:, :2]
boxes_t[:, 2:4] = np.exp(boxes_t[:, 2:4] * config.PRIOR_SCALING[1]) * default_boxes_t[:, 2:4] boxes_t[:, 2:4] = np.exp(boxes_t[:, 2:4] * config.PRIOR_SCALING[1]) * default_boxes_t[:, 2:4]


@@ -149,41 +155,101 @@ def ssd_bboxes_decode(boxes, index):
bboxes[:, [0, 1]] = boxes_t[:, [0, 1]] - boxes_t[:, [2, 3]] / 2 bboxes[:, [0, 1]] = boxes_t[:, [0, 1]] - boxes_t[:, [2, 3]] / 2
bboxes[:, [2, 3]] = boxes_t[:, [0, 1]] + boxes_t[:, [2, 3]] / 2 bboxes[:, [2, 3]] = boxes_t[:, [0, 1]] + boxes_t[:, [2, 3]] / 2


return bboxes
return np.clip(bboxes, 0, 1)


def preprocess_fn(image, box, is_training):
"""Preprocess function for dataset."""


def _rand(a=0., b=1.):
"""Generate random."""
return np.random.rand() * (b - a) + a
def intersect(box_a, box_b):
"""Compute the intersect of two sets of boxes."""
max_yx = np.minimum(box_a[:, 2:4], box_b[2:4])
min_yx = np.maximum(box_a[:, :2], box_b[:2])
inter = np.clip((max_yx - min_yx), a_min=0, a_max=np.inf)
return inter[:, 0] * inter[:, 1]


def _infer_data(image, input_shape, box):
img_h, img_w, _ = image.shape
input_h, input_w = input_shape


scale = min(float(input_w) / float(img_w), float(input_h) / float(img_h))
nw = int(img_w * scale)
nh = int(img_h * scale)
def jaccard_numpy(box_a, box_b):
"""Compute the jaccard overlap of two sets of boxes."""
inter = intersect(box_a, box_b)
area_a = ((box_a[:, 2] - box_a[:, 0]) *
(box_a[:, 3] - box_a[:, 1]))
area_b = ((box_b[2] - box_b[0]) *
(box_b[3] - box_b[1]))
union = area_a + area_b - inter
return inter / union


def random_sample_crop(image, boxes):
"""Random Crop the image and boxes"""
height, width, _ = image.shape
min_iou = np.random.choice([None, 0.1, 0.3, 0.5, 0.7, 0.9])

if min_iou is None:
return image, boxes

# max trails (50)
for _ in range(50):
image_t = image

w = _rand(0.3, 1.0) * width
h = _rand(0.3, 1.0) * height

# aspect ratio constraint b/t .5 & 2
if h / w < 0.5 or h / w > 2:
continue

left = _rand() * (width - w)
top = _rand() * (height - h)

rect = np.array([int(top), int(left), int(top+h), int(left+w)])
overlap = jaccard_numpy(boxes, rect)

# dropout some boxes
drop_mask = overlap > 0
if not drop_mask.any():
continue

if overlap[drop_mask].min() < min_iou:
continue

image_t = image_t[rect[0]:rect[2], rect[1]:rect[3], :]

centers = (boxes[:, :2] + boxes[:, 2:4]) / 2.0

m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])
m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])

# mask in that both m1 and m2 are true
mask = m1 * m2 * drop_mask


image = cv2.resize(image, (nw, nh))
# have any valid boxes? try again if not
if not mask.any():
continue


new_image = np.zeros((input_h, input_w, 3), np.float32)
dh = (input_h - nh) // 2
dw = (input_w - nw) // 2
new_image[dh: (nh + dh), dw: (nw + dw), :] = image
image = new_image
# take only matching gt boxes
boxes_t = boxes[mask, :].copy()

boxes_t[:, :2] = np.maximum(boxes_t[:, :2], rect[:2])
boxes_t[:, :2] -= rect[:2]
boxes_t[:, 2:4] = np.minimum(boxes_t[:, 2:4], rect[2:4])
boxes_t[:, 2:4] -= rect[:2]

return image_t, boxes_t
return image, boxes


def preprocess_fn(img_id, image, box, is_training):
"""Preprocess function for dataset."""
def _infer_data(image, input_shape):
img_h, img_w, _ = image.shape
input_h, input_w = input_shape

image = cv2.resize(image, (input_w, input_h))


#When the channels of image is 1 #When the channels of image is 1
if len(image.shape) == 2: if len(image.shape) == 2:
image = np.expand_dims(image, axis=-1) image = np.expand_dims(image, axis=-1)
image = np.concatenate([image, image, image], axis=-1) image = np.concatenate([image, image, image], axis=-1)


box = box.astype(np.float32)

box[:, [0, 2]] = (box[:, [0, 2]] * scale + dw) / input_w
box[:, [1, 3]] = (box[:, [1, 3]] * scale + dh) / input_h
return image, np.array((img_h, img_w), np.float32), box
return img_id, image, np.array((img_h, img_w), np.float32)


def _data_aug(image, box, is_training, image_size=(300, 300)): def _data_aug(image, box, is_training, image_size=(300, 300)):
"""Data augmentation function.""" """Data augmentation function."""
@@ -191,53 +257,34 @@ def preprocess_fn(image, box, is_training):
w, h = image_size w, h = image_size


if not is_training: if not is_training:
return _infer_data(image, image_size, box)
# Random settings
scale_w = _rand(0.75, 1.25)
scale_h = _rand(0.75, 1.25)
return _infer_data(image, image_size)


flip = _rand() < .5
nw = iw * scale_w
nh = ih * scale_h
scale = min(w / nw, h / nh)
nw = int(scale * nw)
nh = int(scale * nh)
# Random crop
box = box.astype(np.float32)
image, box = random_sample_crop(image, box)
ih, iw, _ = image.shape


# Resize image # Resize image
image = cv2.resize(image, (nw, nh))

# place image
new_image = np.zeros((h, w, 3), dtype=np.float32)
dw = (w - nw) // 2
dh = (h - nh) // 2
new_image[dh:dh + nh, dw:dw + nw, :] = image
image = new_image
image = cv2.resize(image, (w, h))


# Flip image or not # Flip image or not
flip = _rand() < .5
if flip: if flip:
image = cv2.flip(image, 1, dst=None) image = cv2.flip(image, 1, dst=None)


# Convert image to gray or not
gray = _rand() < .25
if gray:
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

# When the channels of image is 1 # When the channels of image is 1
if len(image.shape) == 2: if len(image.shape) == 2:
image = np.expand_dims(image, axis=-1) image = np.expand_dims(image, axis=-1)
image = np.concatenate([image, image, image], axis=-1) image = np.concatenate([image, image, image], axis=-1)


box = box.astype(np.float32)

# Transform box with shape[x1, y1, x2, y2].
box[:, [0, 2]] = (box[:, [0, 2]] * scale * scale_w + dw) / w
box[:, [1, 3]] = (box[:, [1, 3]] * scale * scale_h + dh) / h
box[:, [0, 2]] = box[:, [0, 2]] / ih
box[:, [1, 3]] = box[:, [1, 3]] / iw


if flip: if flip:
box[:, [0, 2]] = 1 - box[:, [2, 0]]
box[:, [1, 3]] = 1 - box[:, [3, 1]]


box, label, num_match_num = ssd_bboxes_encode(box)
return image, box, label, num_match_num
box, label, num_match = ssd_bboxes_encode(box)
return image, box, label, num_match
return _data_aug(image, box, is_training, image_size=config.IMG_SHAPE) return _data_aug(image, box, is_training, image_size=config.IMG_SHAPE)




@@ -265,7 +312,8 @@ def create_coco_label(is_training):
classs_dict[cat["id"]] = cat["name"] classs_dict[cat["id"]] = cat["name"]


image_ids = coco.getImgIds() image_ids = coco.getImgIds()
image_files = []
images = []
image_path_dict = {}
image_anno_dict = {} image_anno_dict = {}


for img_id in image_ids: for img_id in image_ids:
@@ -275,17 +323,23 @@ def create_coco_label(is_training):
anno = coco.loadAnns(anno_ids) anno = coco.loadAnns(anno_ids)
image_path = os.path.join(coco_root, data_type, file_name) image_path = os.path.join(coco_root, data_type, file_name)
annos = [] annos = []
iscrowd = False
for label in anno: for label in anno:
bbox = label["bbox"] bbox = label["bbox"]
class_name = classs_dict[label["category_id"]] class_name = classs_dict[label["category_id"]]
iscrowd = iscrowd or label["iscrowd"]
if class_name in train_cls: if class_name in train_cls:
x_min, x_max = bbox[0], bbox[0] + bbox[2] x_min, x_max = bbox[0], bbox[0] + bbox[2]
y_min, y_max = bbox[1], bbox[1] + bbox[3] y_min, y_max = bbox[1], bbox[1] + bbox[3]
annos.append(list(map(round, [x_min, y_min, x_max, y_max])) + [train_cls_dict[class_name]])
annos.append(list(map(round, [y_min, x_min, y_max, x_max])) + [train_cls_dict[class_name]])
if not is_training and iscrowd:
continue
if len(annos) >= 1: if len(annos) >= 1:
image_files.append(image_path)
image_anno_dict[image_path] = np.array(annos)
return image_files, image_anno_dict
images.append(img_id)
image_path_dict[img_id] = image_path
image_anno_dict[img_id] = np.array(annos)

return images, image_path_dict, image_anno_dict




def anno_parser(annos_str): def anno_parser(annos_str):
@@ -299,7 +353,8 @@ def anno_parser(annos_str):


def filter_valid_data(image_dir, anno_path): def filter_valid_data(image_dir, anno_path):
"""Filter valid image file, which both in image_dir and anno_path.""" """Filter valid image file, which both in image_dir and anno_path."""
image_files = []
images = []
image_path_dict = {}
image_anno_dict = {} image_anno_dict = {}
if not os.path.isdir(image_dir): if not os.path.isdir(image_dir):
raise RuntimeError("Path given is not valid.") raise RuntimeError("Path given is not valid.")
@@ -308,15 +363,17 @@ def filter_valid_data(image_dir, anno_path):


with open(anno_path, "rb") as f: with open(anno_path, "rb") as f:
lines = f.readlines() lines = f.readlines()
for line in lines:
for img_id, line in enumerate(lines):
line_str = line.decode("utf-8").strip() line_str = line.decode("utf-8").strip()
line_split = str(line_str).split(' ') line_split = str(line_str).split(' ')
file_name = line_split[0] file_name = line_split[0]
image_path = os.path.join(image_dir, file_name) image_path = os.path.join(image_dir, file_name)
if os.path.isfile(image_path): if os.path.isfile(image_path):
image_anno_dict[image_path] = anno_parser(line_split[1:])
image_files.append(image_path)
return image_files, image_anno_dict
images.append(img_id)
image_path_dict[img_id] = image_path
image_anno_dict[img_id] = anno_parser(line_split[1:])

return images, image_path_dict, image_anno_dict




def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.mindrecord", file_num=8): def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.mindrecord", file_num=8):
@@ -325,21 +382,24 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.
mindrecord_path = os.path.join(mindrecord_dir, prefix) mindrecord_path = os.path.join(mindrecord_dir, prefix)
writer = FileWriter(mindrecord_path, file_num) writer = FileWriter(mindrecord_path, file_num)
if dataset == "coco": if dataset == "coco":
image_files, image_anno_dict = create_coco_label(is_training)
images, image_path_dict, image_anno_dict = create_coco_label(is_training)
else: else:
image_files, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH)
images, image_path_dict, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH)


ssd_json = { ssd_json = {
"img_id": {"type": "int32", "shape": [1]},
"image": {"type": "bytes"}, "image": {"type": "bytes"},
"annotation": {"type": "int32", "shape": [-1, 5]}, "annotation": {"type": "int32", "shape": [-1, 5]},
} }
writer.add_schema(ssd_json, "ssd_json") writer.add_schema(ssd_json, "ssd_json")


for image_name in image_files:
with open(image_name, 'rb') as f:
for img_id in images:
image_path = image_path_dict[img_id]
with open(image_path, 'rb') as f:
img = f.read() img = f.read()
annos = np.array(image_anno_dict[image_name], dtype=np.int32)
row = {"image": img, "annotation": annos}
annos = np.array(image_anno_dict[img_id], dtype=np.int32)
img_id = np.array([img_id], dtype=np.int32)
row = {"img_id": img_id, "image": img, "annotation": annos}
writer.write_raw_data([row]) writer.write_raw_data([row])
writer.commit() writer.commit()


@@ -347,29 +407,26 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.
def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0, def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0,
is_training=True, num_parallel_workers=4): is_training=True, num_parallel_workers=4):
"""Creatr SSD dataset with MindDataset.""" """Creatr SSD dataset with MindDataset."""
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,
num_parallel_workers=num_parallel_workers, shuffle=is_training)
ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num,
shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training)
decode = C.Decode() decode = C.Decode()
ds = ds.map(input_columns=["image"], operations=decode) ds = ds.map(input_columns=["image"], operations=decode)
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))

change_swap_op = C.HWC2CHW()
normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255])
color_adjust_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
compose_map_func = (lambda img_id, image, annotation: preprocess_fn(img_id, image, annotation, is_training))
if is_training: if is_training:
hwc_to_chw = C.HWC2CHW()
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "box", "label", "num_match_num"],
columns_order=["image", "box", "label", "num_match_num"],
operations=compose_map_func, python_multiprocessing=True, num_parallel_workers=num_parallel_workers)
ds = ds.map(input_columns=["image"], operations=hwc_to_chw, python_multiprocessing=True,
num_parallel_workers=num_parallel_workers)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
output_columns = ["image", "box", "label", "num_match"]
trans = [color_adjust_op, normalize_op, change_swap_op]
else: else:
hwc_to_chw = C.HWC2CHW()
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "annotation"],
columns_order=["image", "image_shape", "annotation"],
operations=compose_map_func)
ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
output_columns = ["img_id", "image", "image_shape"]
trans = [normalize_op, change_swap_op]
ds = ds.map(input_columns=["img_id", "image", "annotation"],
output_columns=output_columns, columns_order=output_columns,
operations=compose_map_func, python_multiprocessing=is_training,
num_parallel_workers=num_parallel_workers)
ds = ds.map(input_columns=["image"], operations=trans, python_multiprocessing=is_training,
num_parallel_workers=num_parallel_workers)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
return ds return ds

+ 10
- 9
example/ssd_coco2017/eval.py View File

@@ -17,6 +17,7 @@
import os import os
import argparse import argparse
import time import time
import numpy as np
from mindspore import context, Tensor from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.model_zoo.ssd import SSD300, ssd_mobilenet_v2 from mindspore.model_zoo.ssd import SSD300, ssd_mobilenet_v2
@@ -26,8 +27,8 @@ from util import metrics


def ssd_eval(dataset_path, ckpt_path): def ssd_eval(dataset_path, ckpt_path):
"""SSD evaluation.""" """SSD evaluation."""
ds = create_ssd_dataset(dataset_path, batch_size=1, repeat_num=1, is_training=False)
batch_size = 32
ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1, is_training=False)
net = SSD300(ssd_mobilenet_v2(), ConfigSSD(), is_training=False) net = SSD300(ssd_mobilenet_v2(), ConfigSSD(), is_training=False)
print("Load Checkpoint!") print("Load Checkpoint!")
param_dict = load_checkpoint(ckpt_path) param_dict = load_checkpoint(ckpt_path)
@@ -35,28 +36,28 @@ def ssd_eval(dataset_path, ckpt_path):
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)


net.set_train(False) net.set_train(False)
i = 1.
total = ds.get_dataset_size()
i = batch_size
total = ds.get_dataset_size() * batch_size
start = time.time() start = time.time()
pred_data = [] pred_data = []
print("\n========================================\n") print("\n========================================\n")
print("total images num: ", total) print("total images num: ", total)
print("Processing, please wait a moment.") print("Processing, please wait a moment.")
for data in ds.create_dict_iterator(): for data in ds.create_dict_iterator():
img_id = data['img_id']
img_np = data['image'] img_np = data['image']
image_shape = data['image_shape'] image_shape = data['image_shape']
annotation = data['annotation']


output = net(Tensor(img_np)) output = net(Tensor(img_np))
for batch_idx in range(img_np.shape[0]): for batch_idx in range(img_np.shape[0]):
pred_data.append({"boxes": output[0].asnumpy()[batch_idx], pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
"box_scores": output[1].asnumpy()[batch_idx], "box_scores": output[1].asnumpy()[batch_idx],
"annotation": annotation,
"image_shape": image_shape})
percent = round(i / total * 100, 2)
"img_id": int(np.squeeze(img_id[batch_idx])),
"image_shape": image_shape[batch_idx]})
percent = round(i / total * 100., 2)


print(f' {str(percent)} [{i}/{total}]', end='\r') print(f' {str(percent)} [{i}/{total}]', end='\r')
i += 1
i += batch_size
cost_time = int((time.time() - start) * 1000) cost_time = int((time.time() - start) * 1000)
print(f' 100% [{total}/{total}] cost {cost_time} ms') print(f' 100% [{total}/{total}] cost {cost_time} ms')
mAP = metrics(pred_data) mAP = metrics(pred_data)


+ 39
- 14
example/ssd_coco2017/run_distribute_train.sh View File

@@ -16,11 +16,17 @@


echo "==============================================================================================================" echo "=============================================================================================================="
echo "Please run the scipt as: " echo "Please run the scipt as: "
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE MINDSPORE_HCCL_CONFIG_PATH"
echo "for example: sh run_distribute_train.sh 8 150 coco /data/hccl.json"
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET MINDSPORE_HCCL_CONFIG_PATH PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
echo "for example: sh run_distribute_train.sh 8 500 0.2 coco /data/hccl.json /opt/ssd-300.ckpt(optional) 200(optional)"
echo "It is better to use absolute path." echo "It is better to use absolute path."
echo "The learning rate is 0.4 as default, if you want other lr, please change the value in this script."
echo "=============================================================================================================="
echo "================================================================================================================="

if [ $# != 5 ] && [ $# != 7 ]
then
echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] \
[MINDSPORE_HCCL_CONFIG_PATH] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
exit 1
fi


# Before start distribute train, first create mindrecord files. # Before start distribute train, first create mindrecord files.
python train.py --only_create_dataset=1 python train.py --only_create_dataset=1
@@ -29,9 +35,11 @@ echo "After running the scipt, the network runs in the background. The log will


export RANK_SIZE=$1 export RANK_SIZE=$1
EPOCH_SIZE=$2 EPOCH_SIZE=$2
DATASET=$3
export MINDSPORE_HCCL_CONFIG_PATH=$4

LR=$3
DATASET=$4
PRE_TRAINED=$6
PRE_TRAINED_EPOCH_SIZE=$7
export MINDSPORE_HCCL_CONFIG_PATH=$5


for((i=0;i<RANK_SIZE;i++)) for((i=0;i<RANK_SIZE;i++))
do do
@@ -43,12 +51,29 @@ do
export RANK_ID=$i export RANK_ID=$i
echo "start training for rank $i, device $DEVICE_ID" echo "start training for rank $i, device $DEVICE_ID"
env > env.log env > env.log
python ../train.py \
--distribute=1 \
--lr=0.4 \
--dataset=$DATASET \
--device_num=$RANK_SIZE \
--device_id=$DEVICE_ID \
--epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
if [ $# == 5 ]
then
python ../train.py \
--distribute=1 \
--lr=$LR \
--dataset=$DATASET \
--device_num=$RANK_SIZE \
--device_id=$DEVICE_ID \
--epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
fi

if [ $# == 7 ]
then
python ../train.py \
--distribute=1 \
--lr=$LR \
--dataset=$DATASET \
--device_num=$RANK_SIZE \
--device_id=$DEVICE_ID \
--pre_trained=$PRE_TRAINED \
--pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \
--epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
fi

cd ../ cd ../
done done

+ 18
- 60
example/ssd_coco2017/train.py View File

@@ -16,79 +16,34 @@
"""train SSD and get checkpoint files.""" """train SSD and get checkpoint files."""


import os import os
import math
import argparse import argparse
import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context, Tensor from mindspore import context, Tensor
from mindspore.communication.management import init from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
from mindspore.train import Model, ParallelMode from mindspore.train import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common.initializer import initializer

from mindspore.model_zoo.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2 from mindspore.model_zoo.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2
from config import ConfigSSD from config import ConfigSSD
from dataset import create_ssd_dataset, data_to_mindrecord_byte_image from dataset import create_ssd_dataset, data_to_mindrecord_byte_image
from util import get_lr, init_net_param




def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
"""
generate learning rate array

Args:
global_step(int): total steps of the training
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch

Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_end + (lr_max - lr_end) * \
(1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)

current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]

return learning_rate


def init_net_param(network, initialize_mode='XavierUniform'):
"""Init the parameters in net."""
params = network.trainable_params()
for p in params:
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
p.set_parameter_data(initializer(initialize_mode, p.data.shape(), p.data.dtype()))

def main(): def main():
parser = argparse.ArgumentParser(description="SSD training") parser = argparse.ArgumentParser(description="SSD training")
parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create " parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create "
"Mindrecord, default is false.")
parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
"Mindrecord, default is False.")
parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is False.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--lr", type=float, default=0.25, help="Learning rate, default is 0.25.")
parser.add_argument("--lr", type=float, default=0.1, help="Learning rate, default is 0.1.")
parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.") parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, defalut is coco.") parser.add_argument("--dataset", type=str, default="coco", help="Dataset, defalut is coco.")
parser.add_argument("--epoch_size", type=int, default=70, help="Epoch size, default is 70.")
parser.add_argument("--epoch_size", type=int, default=250, help="Epoch size, default is 250.")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.") parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.") parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.")
parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.")
parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.")
parser.add_argument("--save_checkpoint_epochs", type=int, default=10, help="Save checkpoint epochs, default is 5.")
parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
args_opt = parser.parse_args() args_opt = parser.parse_args()


@@ -142,7 +97,8 @@ def main():
dataset_size = dataset.get_dataset_size() dataset_size = dataset.get_dataset_size()
print("Create dataset done!") print("Create dataset done!")


ssd = SSD300(backbone=ssd_mobilenet_v2(), config=config)
backbone = ssd_mobilenet_v2()
ssd = SSD300(backbone=backbone, config=config)
net = SSDWithLossCell(ssd, config) net = SSDWithLossCell(ssd, config)
init_net_param(net) init_net_param(net)


@@ -150,17 +106,19 @@ def main():
ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs) ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=None, config=ckpt_config) ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=None, config=ckpt_config)


lr = Tensor(get_lr(global_step=0, lr_init=0, lr_end=0, lr_max=args_opt.lr,
warmup_epochs=max(args_opt.epoch_size // 20, 1),
total_epochs=args_opt.epoch_size,
steps_per_epoch=dataset_size))
opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 0.0001, loss_scale)
net = TrainingWrapper(net, opt, loss_scale)

if args_opt.pre_trained: if args_opt.pre_trained:
if args_opt.pre_trained_epoch_size <= 0:
raise KeyError("pre_trained_epoch_size must be greater than 0.")
param_dict = load_checkpoint(args_opt.pre_trained) param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)


lr = Tensor(get_lr(global_step=0, lr_init=0.001, lr_end=0.001 * args_opt.lr, lr_max=args_opt.lr,
warmup_epochs=2,
total_epochs=args_opt.epoch_size,
steps_per_epoch=dataset_size))
opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 1e-4, loss_scale)
net = TrainingWrapper(net, opt, loss_scale)

callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]


model = Model(net) model = Model(net)


+ 120
- 134
example/ssd_coco2017/util.py View File

@@ -14,43 +14,83 @@
# ============================================================================ # ============================================================================
"""metrics utils""" """metrics utils"""


import os
import json
import math
import numpy as np import numpy as np
from mindspore import Tensor
from mindspore.common.initializer import initializer, TruncatedNormal
from config import ConfigSSD from config import ConfigSSD
from dataset import ssd_bboxes_decode from dataset import ssd_bboxes_decode




def calc_iou(bbox_pred, bbox_ground):
"""Calculate iou of predicted bbox and ground truth."""
bbox_pred = np.expand_dims(bbox_pred, axis=0)

pred_w = bbox_pred[:, 2] - bbox_pred[:, 0]
pred_h = bbox_pred[:, 3] - bbox_pred[:, 1]
pred_area = pred_w * pred_h

gt_w = bbox_ground[:, 2] - bbox_ground[:, 0]
gt_h = bbox_ground[:, 3] - bbox_ground[:, 1]
gt_area = gt_w * gt_h

iw = np.minimum(bbox_pred[:, 2], bbox_ground[:, 2]) - np.maximum(bbox_pred[:, 0], bbox_ground[:, 0])
ih = np.minimum(bbox_pred[:, 3], bbox_ground[:, 3]) - np.maximum(bbox_pred[:, 1], bbox_ground[:, 1])

iw = np.maximum(iw, 0)
ih = np.maximum(ih, 0)
intersection_area = iw * ih
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
"""
generate learning rate array

Args:
global_step(int): total steps of the training
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch

Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_end + \
(lr_max - lr_end) * \
(1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)

current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]

return learning_rate


def init_net_param(network, initialize_mode='TruncatedNormal'):
"""Init the parameters in net."""
params = network.trainable_params()
for p in params:
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
if initialize_mode == 'TruncatedNormal':
p.set_parameter_data(initializer(TruncatedNormal(0.03), p.data.shape(), p.data.dtype()))
else:
p.set_parameter_data(initialize_mode, p.data.shape(), p.data.dtype())


union_area = pred_area + gt_area - intersection_area
union_area = np.maximum(union_area, np.finfo(float).eps)


iou = intersection_area * 1. / union_area
return iou
def load_backbone_params(network, param_dict):
"""Init the parameters from pre-train model, default is mobilenetv2."""
for _, param in net.parameters_and_names():
param_name = param.name.replace('network.backbone.', '')
name_split = param_name.split('.')
if 'features_1' in param_name:
param_name = param_name.replace('features_1', 'features')
if 'features_2' in param_name:
param_name = '.'.join(['features', str(int(name_split[1]) + 14)] + name_split[2:])
if param_name in param_dict:
param.set_parameter_data(param_dict[param_name].data)




def apply_nms(all_boxes, all_scores, thres, max_boxes): def apply_nms(all_boxes, all_scores, thres, max_boxes):
"""Apply NMS to bboxes.""" """Apply NMS to bboxes."""
x1 = all_boxes[:, 0]
y1 = all_boxes[:, 1]
x2 = all_boxes[:, 2]
y2 = all_boxes[:, 3]
y1 = all_boxes[:, 0]
x1 = all_boxes[:, 1]
y2 = all_boxes[:, 2]
x2 = all_boxes[:, 3]
areas = (x2 - x1 + 1) * (y2 - y1 + 1) areas = (x2 - x1 + 1) * (y2 - y1 + 1)


order = all_scores.argsort()[::-1] order = all_scores.argsort()[::-1]
@@ -80,127 +120,73 @@ def apply_nms(all_boxes, all_scores, thres, max_boxes):
return keep return keep




def calc_ap(recall, precision):
"""Calculate AP."""
correct_recall = np.concatenate(([0.], recall, [1.]))
correct_precision = np.concatenate(([0.], precision, [0.]))

for i in range(correct_recall.size - 1, 0, -1):
correct_precision[i - 1] = np.maximum(correct_precision[i - 1], correct_precision[i])

i = np.where(correct_recall[1:] != correct_recall[:-1])[0]

ap = np.sum((correct_recall[i + 1] - correct_recall[i]) * correct_precision[i + 1])

return ap

def metrics(pred_data): def metrics(pred_data):
"""Calculate mAP of predicted bboxes.""" """Calculate mAP of predicted bboxes."""
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
config = ConfigSSD() config = ConfigSSD()
num_classes = config.NUM_CLASSES num_classes = config.NUM_CLASSES


all_detections = [None for i in range(num_classes)]
all_pred_scores = [None for i in range(num_classes)]
all_annotations = [None for i in range(num_classes)]
average_precisions = {}
num = [0 for i in range(num_classes)]
accurate_num = [0 for i in range(num_classes)]
coco_root = config.COCO_ROOT
data_type = config.VAL_DATA_TYPE


for sample in pred_data:
pred_boxes = sample['boxes']
boxes_scores = sample['box_scores']
annotation = sample['annotation']
#Classes need to train or test.
val_cls = config.COCO_CLASSES
val_cls_dict = {}
for i, cls in enumerate(val_cls):
val_cls_dict[i] = cls


annotation = np.squeeze(annotation, axis=0)
anno_json = os.path.join(coco_root, config.INSTANCES_SET.format(data_type))
coco_gt = COCO(anno_json)
classs_dict = {}
cat_ids = coco_gt.loadCats(coco_gt.getCatIds())
for cat in cat_ids:
classs_dict[cat["name"]] = cat["id"]


pred_labels = np.argmax(boxes_scores, axis=-1)
index = np.nonzero(pred_labels)
pred_boxes = ssd_bboxes_decode(pred_boxes, index)
predictions = []
img_ids = []


pred_boxes = pred_boxes.clip(0, 1)
boxes_scores = np.max(boxes_scores, axis=-1)
boxes_scores = boxes_scores[index]
pred_labels = pred_labels[index]
for sample in pred_data:
pred_boxes = sample['boxes']
box_scores = sample['box_scores']
img_id = sample['img_id']
h, w = sample['image_shape']


top_k = 50
pred_boxes = ssd_bboxes_decode(pred_boxes)
final_boxes = []
final_label = []
final_score = []
img_ids.append(img_id)


for c in range(1, num_classes): for c in range(1, num_classes):
if len(pred_labels) >= 1:
class_box_scores = boxes_scores[pred_labels == c]
class_boxes = pred_boxes[pred_labels == c]

nms_index = apply_nms(class_boxes, class_box_scores, config.MATCH_THRESHOLD, top_k)
class_box_scores = box_scores[:, c]
score_mask = class_box_scores > config.MIN_SCORE
class_box_scores = class_box_scores[score_mask]
class_boxes = pred_boxes[score_mask] * [h, w, h, w]


if score_mask.any():
nms_index = apply_nms(class_boxes, class_box_scores, config.NMS_THRESHOLD, config.TOP_K)
class_boxes = class_boxes[nms_index] class_boxes = class_boxes[nms_index]
class_box_scores = class_box_scores[nms_index] class_box_scores = class_box_scores[nms_index]


cmask = class_box_scores > 0.5
class_boxes = class_boxes[cmask]
class_box_scores = class_box_scores[cmask]

all_detections[c] = class_boxes
all_pred_scores[c] = class_box_scores

for c in range(1, num_classes):
if len(annotation) >= 1:
all_annotations[c] = annotation[annotation[:, 4] == c, :4]

for c in range(1, num_classes):
false_positives = np.zeros((0,))
true_positives = np.zeros((0,))
scores = np.zeros((0,))
num_annotations = 0.0

annotations = all_annotations[c]
num_annotations += annotations.shape[0]
detections = all_detections[c]
pred_scores = all_pred_scores[c]

for index, detection in enumerate(detections):
scores = np.append(scores, pred_scores[index])
if len(annotations) >= 1:
IoUs = calc_iou(detection, annotations)
assigned_anno = np.argmax(IoUs)
max_overlap = IoUs[assigned_anno]

if max_overlap >= 0.5:
false_positives = np.append(false_positives, 0)
true_positives = np.append(true_positives, 1)
else:
false_positives = np.append(false_positives, 1)
true_positives = np.append(true_positives, 0)
else:
false_positives = np.append(false_positives, 1)
true_positives = np.append(true_positives, 0)

if num_annotations == 0:
if c not in average_precisions.keys():
average_precisions[c] = 0
continue
accurate_num[c] = 1
indices = np.argsort(-scores)
false_positives = false_positives[indices]
true_positives = true_positives[indices]

false_positives = np.cumsum(false_positives)
true_positives = np.cumsum(true_positives)

recall = true_positives * 1. / num_annotations
precision = true_positives * 1. / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)

average_precision = calc_ap(recall, precision)

if c not in average_precisions.keys():
average_precisions[c] = average_precision
else:
average_precisions[c] += average_precision

num[c] += 1

count = 0
for key in average_precisions:
if num[key] != 0:
count += (average_precisions[key] / num[key])

mAP = count * 1. / accurate_num.count(1)
return mAP
final_boxes += class_boxes.tolist()
final_score += class_box_scores.tolist()
final_label += [classs_dict[val_cls_dict[c]]] * len(class_box_scores)

for loc, label, score in zip(final_boxes, final_label, final_score):
res = {}
res['image_id'] = img_id
res['bbox'] = [loc[1], loc[0], loc[3] - loc[1], loc[2] - loc[0]]
res['score'] = score
res['category_id'] = label
predictions.append(res)
with open('predictions.json', 'w') as f:
json.dump(predictions, f)

coco_dt = coco_gt.loadRes('predictions.json')
E = COCOeval(coco_gt, coco_dt, iouType='bbox')
E.params.imgIds = img_ids
E.evaluate()
E.accumulate()
E.summarize()
return E.stats[0]

+ 114
- 122
mindspore/model_zoo/ssd.py View File

@@ -17,22 +17,13 @@
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
import mindspore as ms import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context
from mindspore import Parameter, context, Tensor
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size from mindspore.communication.management import get_group_size
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.ops.operations import TensorAdd
from mindspore import Parameter


def _conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same'):
weight_shape = (out_channel, in_channel, kernel_size, kernel_size)
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()
return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,
padding=0, pad_mode=pad_mod, weight_init=weight)




def _make_divisible(v, divisor, min_value=None): def _make_divisible(v, divisor, min_value=None):
@@ -46,6 +37,55 @@ def _make_divisible(v, divisor, min_value=None):
return new_v return new_v




def _conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same'):
return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,
padding=0, pad_mode=pad_mod, has_bias=True)


def _bn(channel):
return nn.BatchNorm2d(channel, eps=1e-3, momentum=0.97,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _last_conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same', pad=0):
depthwise_conv = DepthwiseConv(in_channel, kernel_size, stride, pad_mode='same', pad=pad)
conv = _conv2d(in_channel, out_channel, kernel_size=1)
return nn.SequentialCell([depthwise_conv, _bn(in_channel), nn.ReLU6(), conv])


class ConvBNReLU(nn.Cell):
"""
Convolution/Depthwise fused with Batchnorm and ReLU block definition.

Args:
in_planes (int): Input channel.
out_planes (int): Output channel.
kernel_size (int): Input kernel size.
stride (int): Stride size for the first convolutional layer. Default: 1.
groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1.

Returns:
Tensor, output tensor.

Examples:
>>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
"""
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
super(ConvBNReLU, self).__init__()
padding = 0
if groups == 1:
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same',
padding=padding)
else:
conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='same', pad=padding)
layers = [conv, _bn(out_planes), nn.ReLU6()]
self.features = nn.SequentialCell(layers)

def construct(self, x):
output = self.features(x)
return output


class DepthwiseConv(nn.Cell): class DepthwiseConv(nn.Cell):
""" """
Depthwise Convolution warpper definition. Depthwise Convolution warpper definition.
@@ -64,6 +104,7 @@ class DepthwiseConv(nn.Cell):
Examples: Examples:
>>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1) >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1)
""" """

def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
super(DepthwiseConv, self).__init__() super(DepthwiseConv, self).__init__()
self.has_bias = has_bias self.has_bias = has_bias
@@ -91,42 +132,9 @@ class DepthwiseConv(nn.Cell):
return output return output




class ConvBNReLU(nn.Cell):
"""
Convolution/Depthwise fused with Batchnorm and ReLU block definition.

Args:
in_planes (int): Input channel.
out_planes (int): Output channel.
kernel_size (int): Input kernel size.
stride (int): Stride size for the first convolutional layer. Default: 1.
groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1.

Returns:
Tensor, output tensor.

Examples:
>>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
"""
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
if groups == 1:
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
padding=padding)
else:
conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding)
layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()]
self.features = nn.SequentialCell(layers)

def construct(self, x):
output = self.features(x)
return output


class InvertedResidual(nn.Cell): class InvertedResidual(nn.Cell):
""" """
Mobilenetv2 residual block definition.
Residual block definition.


Args: Args:
inp (int): Input channel. inp (int): Input channel.
@@ -140,7 +148,7 @@ class InvertedResidual(nn.Cell):
Examples: Examples:
>>> ResidualBlock(3, 256, 1, 1) >>> ResidualBlock(3, 256, 1, 1)
""" """
def __init__(self, inp, oup, stride, expand_ratio):
def __init__(self, inp, oup, stride, expand_ratio, last_relu=False):
super(InvertedResidual, self).__init__() super(InvertedResidual, self).__init__()
assert stride in [1, 2] assert stride in [1, 2]


@@ -155,17 +163,21 @@ class InvertedResidual(nn.Cell):
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
# pw-linear # pw-linear
nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False), nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False),
nn.BatchNorm2d(oup),
_bn(oup),
]) ])
self.conv = nn.SequentialCell(layers) self.conv = nn.SequentialCell(layers)
self.add = TensorAdd()
self.add = P.TensorAdd()
self.cast = P.Cast() self.cast = P.Cast()
self.last_relu = last_relu
self.relu = nn.ReLU6()


def construct(self, x): def construct(self, x):
identity = x identity = x
x = self.conv(x) x = self.conv(x)
if self.use_res_connect: if self.use_res_connect:
return self.add(identity, x)
x = self.add(identity, x)
if self.last_relu:
x = self.relu(x)
return x return x




@@ -214,10 +226,10 @@ class MultiBox(nn.Cell):
loc_layers = [] loc_layers = []
cls_layers = [] cls_layers = []
for k, out_channel in enumerate(out_channels): for k, out_channel in enumerate(out_channels):
loc_layers += [_conv2d(out_channel, 4 * num_default[k],
kernel_size=3, stride=1, pad_mod='same')]
cls_layers += [_conv2d(out_channel, num_classes * num_default[k],
kernel_size=3, stride=1, pad_mod='same')]
loc_layers += [_last_conv2d(out_channel, 4 * num_default[k],
kernel_size=3, stride=1, pad_mod='same', pad=0)]
cls_layers += [_last_conv2d(out_channel, num_classes * num_default[k],
kernel_size=3, stride=1, pad_mod='same', pad=0)]


self.multi_loc_layers = nn.layer.CellList(loc_layers) self.multi_loc_layers = nn.layer.CellList(loc_layers)
self.multi_cls_layers = nn.layer.CellList(cls_layers) self.multi_cls_layers = nn.layer.CellList(cls_layers)
@@ -258,13 +270,14 @@ class SSD300(nn.Cell):
strides = config.EXTRAS_STRIDES strides = config.EXTRAS_STRIDES
residual_list = [] residual_list = []
for i in range(2, len(in_channels)): for i in range(2, len(in_channels)):
residual = InvertedResidual(in_channels[i], out_channels[i], stride=strides[i], expand_ratio=ratios[i])
residual = InvertedResidual(in_channels[i], out_channels[i], stride=strides[i],
expand_ratio=ratios[i], last_relu=True)
residual_list.append(residual) residual_list.append(residual)
self.multi_residual = nn.layer.CellList(residual_list) self.multi_residual = nn.layer.CellList(residual_list)
self.multi_box = MultiBox(config) self.multi_box = MultiBox(config)
self.is_training = is_training self.is_training = is_training
if not is_training: if not is_training:
self.softmax = P.Softmax()
self.activation = P.Sigmoid()


def construct(self, x): def construct(self, x):
layer_out_13, output = self.backbone(x) layer_out_13, output = self.backbone(x)
@@ -275,77 +288,42 @@ class SSD300(nn.Cell):
multi_feature += (feature,) multi_feature += (feature,)
pred_loc, pred_label = self.multi_box(multi_feature) pred_loc, pred_label = self.multi_box(multi_feature)
if not self.is_training: if not self.is_training:
pred_label = self.softmax(pred_label)
pred_label = self.activation(pred_label)
return pred_loc, pred_label return pred_loc, pred_label




class LocalizationLoss(nn.Cell):
class SigmoidFocalClassificationLoss(nn.Cell):
"""" """"
Computes the localization loss with SmoothL1Loss.

Returns:
Tensor, box regression loss.
"""
def __init__(self):
super(LocalizationLoss, self).__init__()
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.loss = nn.SmoothL1Loss()
self.expand_dims = P.ExpandDims()
self.less = P.Less()

def construct(self, pred_loc, gt_loc, gt_label, num_matched_boxes):
mask = F.cast(self.less(0, gt_label), mstype.float32)
mask = self.expand_dims(mask, -1)
smooth_l1 = self.loss(gt_loc, pred_loc) * mask
box_loss = self.reduce_sum(smooth_l1, 1)
return self.reduce_mean(box_loss / F.cast(num_matched_boxes, mstype.float32), (0, 1))


class ClassificationLoss(nn.Cell):
""""
Computes the classification loss with hard example mining.
Sigmoid focal-loss for classification.


Args: Args:
config (Class): The default config of SSD.
gamma (float): Hyper-parameter to balance the easy and hard examples. Default: 2.0
alpha (float): Hyper-parameter to balance the positive and negative example. Default: 0.25


Returns: Returns:
Tensor, classification loss.
Tensor, the focal loss.
""" """
def __init__(self, config):
super(ClassificationLoss, self).__init__()
self.num_classes = config.NUM_CLASSES
self.num_boxes = config.NUM_SSD_BOXES
self.neg_pre_positive = config.NEG_PRE_POSITIVE
self.minimum = P.Minimum()
self.less = P.Less()
self.sort = P.TopK()
self.tile = P.Tile()
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.expand_dims = P.ExpandDims()
self.sort_descend = P.TopK(True)
self.cross_entropy = nn.SoftmaxCrossEntropyWithLogits(sparse=True)

def construct(self, pred_label, gt_label, num_matched_boxes):
gt_label = F.cast(gt_label, mstype.int32)
mask = F.cast(self.less(0, gt_label), mstype.float32)
gt_label_shape = F.shape(gt_label)
pred_label = F.reshape(pred_label, (-1, self.num_classes))
gt_label = F.reshape(gt_label, (-1,))
cross_entropy = self.cross_entropy(pred_label, gt_label)
cross_entropy = F.reshape(cross_entropy, gt_label_shape)

# Hard example mining
num_matched_boxes = F.reshape(num_matched_boxes, (-1,))
neg_masked_cross_entropy = F.cast(cross_entropy * (1- mask), mstype.float16)
_, loss_idx = self.sort_descend(neg_masked_cross_entropy, self.num_boxes)
_, relative_position = self.sort(F.cast(loss_idx, mstype.float16), self.num_boxes)
num_neg_boxes = self.minimum(num_matched_boxes * self.neg_pre_positive, self.num_boxes)
tile_num_neg_boxes = self.tile(self.expand_dims(num_neg_boxes, -1), (1, self.num_boxes))
top_k_neg_mask = F.cast(self.less(relative_position, tile_num_neg_boxes), mstype.float32)
class_loss = self.reduce_sum(cross_entropy * (mask + top_k_neg_mask), 1)
return self.reduce_mean(class_loss / F.cast(num_matched_boxes, mstype.float32), 0)
def __init__(self, gamma=2.0, alpha=0.75):
super(SigmoidFocalClassificationLoss, self).__init__()
self.sigmiod_cross_entropy = P.SigmoidCrossEntropyWithLogits()
self.sigmoid = P.Sigmoid()
self.pow = P.Pow()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.gamma = gamma
self.alpha = alpha

def construct(self, logits, label):
label = self.onehot(label, F.shape(logits)[-1], self.on_value, self.off_value)
sigmiod_cross_entropy = self.sigmiod_cross_entropy(logits, label)
sigmoid = self.sigmoid(logits)
label = F.cast(label, mstype.float32)
p_t = label * sigmoid + (1 - label) * (1 - sigmoid)
modulating_factor = self.pow(1 - p_t, self.gamma)
alpha_weight_factor = label * self.alpha + (1 - label) * (1 - self.alpha)
focal_loss = modulating_factor * alpha_weight_factor * sigmiod_cross_entropy
return focal_loss




class SSDWithLossCell(nn.Cell): class SSDWithLossCell(nn.Cell):
@@ -362,14 +340,29 @@ class SSDWithLossCell(nn.Cell):
def __init__(self, network, config): def __init__(self, network, config):
super(SSDWithLossCell, self).__init__() super(SSDWithLossCell, self).__init__()
self.network = network self.network = network
self.class_loss = ClassificationLoss(config)
self.box_loss = LocalizationLoss()
self.less = P.Less()
self.tile = P.Tile()
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.expand_dims = P.ExpandDims()
self.class_loss = SigmoidFocalClassificationLoss()
self.loc_loss = nn.SmoothL1Loss()


def construct(self, x, gt_loc, gt_label, num_matched_boxes): def construct(self, x, gt_loc, gt_label, num_matched_boxes):
pred_loc, pred_label = self.network(x) pred_loc, pred_label = self.network(x)
loss_cls = self.class_loss(pred_label, gt_label, num_matched_boxes)
loss_loc = self.box_loss(pred_loc, gt_loc, gt_label, num_matched_boxes)
return loss_cls + loss_loc
mask = F.cast(self.less(0, gt_label), mstype.float32)
num_matched_boxes = self.reduce_sum(F.cast(num_matched_boxes, mstype.float32))

# Localization Loss
mask_loc = self.tile(self.expand_dims(mask, -1), (1, 1, 4))
smooth_l1 = self.loc_loss(pred_loc, gt_loc) * mask_loc
loss_loc = self.reduce_sum(self.reduce_mean(smooth_l1, -1), -1)

# Classification Loss
loss_cls = self.class_loss(pred_label, gt_label)
loss_cls = self.reduce_sum(loss_cls, (1, 2))

return self.reduce_sum((loss_cls + loss_loc) / num_matched_boxes)




class TrainingWrapper(nn.Cell): class TrainingWrapper(nn.Cell):
@@ -415,7 +408,6 @@ class TrainingWrapper(nn.Cell):
return F.depend(loss, self.optimizer(grads)) return F.depend(loss, self.optimizer(grads))





class SSDWithMobileNetV2(nn.Cell): class SSDWithMobileNetV2(nn.Cell):
""" """
MobileNetV2 architecture for SSD backbone. MobileNetV2 architecture for SSD backbone.


Loading…
Cancel
Save