@@ -7,6 +7,7 @@
 * DeepFM: a factorization-machine based neural network for CTR prediction on Criteo dataset.
 * DeepLabV3: significantly improves over our previous DeepLab versions without DenseCRF post-processing and attains comparable performance with other state-of-the-art models on the PASCAL VOC 2012 semantic image segmentation benchmark.
 * Faster-RCNN: towards real-time object detection with region proposal networks on COCO 2017 dataset.
+* SSD: a single-stage object detection method on COCO 2017 dataset.
 * GoogLeNet: a deep convolutional neural network architecture codenamed Inception V1 for classification and detection on CIFAR-10 dataset.
 * Wide&Deep: jointly trained wide linear models and deep neural networks for recommender systems on Criteo dataset.
 * Frontend and User Interface
@@ -60,10 +60,10 @@ To train the model, run `train.py`. If the `MINDRECORD_DIR` is empty, it will ge
 - Distribute mode
 ```
-sh run_distribute_train.sh 8 150 coco /data/hccl.json
+sh run_distribute_train.sh 8 500 0.2 coco /data/hccl.json
 ```
-The input parameters are device numbers, epoch size, dataset mode and [hccl json configuration file](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). **It is better to use absolute path.**
+The input parameters are the device number, epoch size, learning rate, dataset mode and [hccl json configuration file](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). **It is better to use absolute paths.**
 You will get the loss value of each step as follows:
@@ -75,14 +75,15 @@ epoch: 3 step: 455, loss is 5.458992
 epoch: 148 step: 455, loss is 1.8340507
 epoch: 149 step: 455, loss is 2.0876894
 epoch: 150 step: 455, loss is 2.239692
+...
 ```
 ### Evaluation
-for evaluation , run `eval.py` with `ckpt_path`. `ckpt_path` is the path of [checkpoint](https://www.mindspore.cn/tutorial/en/master/use/saving_and_loading_model_parameters.html) file.
+For evaluation, run `eval.py` with `checkpoint_path`. `checkpoint_path` is the path of the [checkpoint](https://www.mindspore.cn/tutorial/en/master/use/saving_and_loading_model_parameters.html) file.
 ```
-python eval.py --ckpt_path ssd.ckpt --dataset coco
+python eval.py --checkpoint_path ssd.ckpt --dataset coco
 ```
 You can run `python eval.py -h` to get more information.
@@ -27,6 +27,9 @@ class ConfigSSD:
     NUM_SSD_BOXES = 1917
     NEG_PRE_POSITIVE = 3
     MATCH_THRESHOLD = 0.5
+    NMS_THRESHOLD = 0.6
+    MIN_SCORE = 0.05
+    TOP_K = 100
     NUM_DEFAULT = [3, 6, 6, 6, 6, 6]
     EXTRAS_IN_CHANNELS = [256, 576, 1280, 512, 256, 256]
@@ -34,20 +37,21 @@ class ConfigSSD:
     EXTRAS_STRIDES = [1, 1, 2, 2, 2, 2]
     EXTRAS_RATIO = [0.2, 0.2, 0.2, 0.25, 0.5, 0.25]
     FEATURE_SIZE = [19, 10, 5, 3, 2, 1]
-    SCALES = [21, 45, 99, 153, 207, 261, 315]
-    ASPECT_RATIOS = [(1,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)]
+    MIN_SCALE = 0.2
+    MAX_SCALE = 0.95
+    ASPECT_RATIOS = [(2,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)]
     STEPS = (16, 32, 64, 100, 150, 300)
     PRIOR_SCALING = (0.1, 0.2)
     # It is better to use absolute paths for `MINDRECORD_DIR` and `COCO_ROOT`.
-    MINDRECORD_DIR = "MindRecord_COCO"
-    COCO_ROOT = "coco2017"
+    MINDRECORD_DIR = "/data/MindRecord_COCO"
+    COCO_ROOT = "/data/coco2017"
     TRAIN_DATA_TYPE = "train2017"
     VAL_DATA_TYPE = "val2017"
     INSTANCES_SET = "annotations/instances_{}.json"
     COCO_CLASSES = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
-                    'train', 'truck', 'boat', 'traffic light', 'fire', 'hydrant',
+                    'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
                     'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                     'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
                     'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
@@ -58,7 +62,7 @@ class ConfigSSD:
                     'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                     'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                     'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
-                    'keyboard', 'cell phone', 'microwave oven', 'toaster', 'sink',
+                    'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
                     'refrigerator', 'book', 'clock', 'vase', 'scissors',
                     'teddy bear', 'hair drier', 'toothbrush')
     NUM_CLASSES = len(COCO_CLASSES)
@@ -32,36 +32,38 @@ config = ConfigSSD()
 class GeneratDefaultBoxes():
     """
     Generate default boxes for SSD, following the order of (W, H, anchor_sizes).
-    `self.default_boxes` has a shape of [anchor_sizes, H, W, 4], the last dimension is [x, y, w, h].
-    `self.default_boxes_ltrb` has the same shape as `self.default_boxes`, the last dimension is [x1, y1, x2, y2].
+    `self.default_boxes` has a shape of [anchor_sizes, H, W, 4], the last dimension is [y, x, h, w].
+    `self.default_boxes_ltrb` has the same shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2].
     """
     def __init__(self):
         fk = config.IMG_SHAPE[0] / np.array(config.STEPS)
+        scale_rate = (config.MAX_SCALE - config.MIN_SCALE) / (len(config.NUM_DEFAULT) - 1)
+        scales = [config.MIN_SCALE + scale_rate * i for i in range(len(config.NUM_DEFAULT))] + [1.0]
         self.default_boxes = []
         for idex, feature_size in enumerate(config.FEATURE_SIZE):
-            sk1 = config.SCALES[idex] / config.IMG_SHAPE[0]
-            sk2 = config.SCALES[idex + 1] / config.IMG_SHAPE[0]
+            sk1 = scales[idex]
+            sk2 = scales[idex + 1]
             sk3 = math.sqrt(sk1 * sk2)
-            if config.NUM_DEFAULT[idex] == 3:
-                all_sizes = [(0.5, 1.0), (1.0, 1.0), (1.0, 0.5)]
+            if idex == 0:
+                w, h = sk1 * math.sqrt(2), sk1 / math.sqrt(2)
+                all_sizes = [(0.1, 0.1), (w, h), (h, w)]
             else:
-                all_sizes = [(sk1, sk1), (sk3, sk3)]
+                all_sizes = [(sk1, sk1)]
                 for aspect_ratio in config.ASPECT_RATIOS[idex]:
                     w, h = sk1 * math.sqrt(aspect_ratio), sk1 / math.sqrt(aspect_ratio)
                     all_sizes.append((w, h))
                     all_sizes.append((h, w))
+                all_sizes.append((sk3, sk3))
             assert len(all_sizes) == config.NUM_DEFAULT[idex]
             for i, j in it.product(range(feature_size), repeat=2):
                 for w, h in all_sizes:
                     cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex]
-                    box = [np.clip(k, 0, 1) for k in (cx, cy, w, h)]
-                    self.default_boxes.append(box)
+                    self.default_boxes.append([cy, cx, h, w])
-        def to_ltrb(cx, cy, w, h):
-            return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
+        def to_ltrb(cy, cx, h, w):
+            return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2
         # For IoU calculation
         self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32')
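For reference, a small standalone sketch (not part of the diff) of what the new `MIN_SCALE`/`MAX_SCALE` computation yields under the config values above, plus a check that the feature-map grids and `NUM_DEFAULT` still account for `NUM_SSD_BOXES = 1917`:

```python
# Standalone sanity check; values copied from ConfigSSD above.
FEATURE_SIZE = [19, 10, 5, 3, 2, 1]
NUM_DEFAULT = [3, 6, 6, 6, 6, 6]
MIN_SCALE, MAX_SCALE = 0.2, 0.95

scale_rate = (MAX_SCALE - MIN_SCALE) / (len(NUM_DEFAULT) - 1)
scales = [MIN_SCALE + scale_rate * i for i in range(len(NUM_DEFAULT))] + [1.0]
print(scales)  # ~[0.2, 0.35, 0.5, 0.65, 0.8, 0.95, 1.0]

# 19*19*3 + 10*10*6 + 5*5*6 + 3*3*6 + 2*2*6 + 1*1*6 = 1917
print(sum(f * f * n for f, n in zip(FEATURE_SIZE, NUM_DEFAULT)))
```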
@@ -70,17 +72,22 @@ class GeneratDefaultBoxes():
 default_boxes_ltrb = GeneratDefaultBoxes().default_boxes_ltrb
 default_boxes = GeneratDefaultBoxes().default_boxes
-x1, y1, x2, y2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1)
+y1, x1, y2, x2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1)
 vol_anchors = (x2 - x1) * (y2 - y1)
 matching_threshold = config.MATCH_THRESHOLD
-def _rand(a=0., b=1.):
-    """Generate random."""
-    return np.random.rand() * (b - a) + a
 def ssd_bboxes_encode(boxes):
     """
     Labels anchors with ground truth inputs.
     Args:
-        boxes: ground truth with shape [N, 5], for each row, it stores [x, y, w, h, cls].
+        boxes: ground truth with shape [N, 5], for each row, it stores [y, x, h, w, cls].
     Returns:
         gt_loc: location ground truth with shape [num_anchors, 4].
@@ -91,10 +98,10 @@ def ssd_bboxes_encode(boxes):
     def jaccard_with_anchors(bbox):
         """Compute the jaccard score between a box and the anchors."""
         # Intersection bbox and volume.
-        xmin = np.maximum(x1, bbox[0])
-        ymin = np.maximum(y1, bbox[1])
-        xmax = np.minimum(x2, bbox[2])
-        ymax = np.minimum(y2, bbox[3])
+        ymin = np.maximum(y1, bbox[0])
+        xmin = np.maximum(x1, bbox[1])
+        ymax = np.minimum(y2, bbox[2])
+        xmax = np.minimum(x2, bbox[3])
         w = np.maximum(xmax - xmin, 0.)
         h = np.maximum(ymax - ymin, 0.)
@@ -110,12 +117,11 @@ def ssd_bboxes_encode(boxes):
     for bbox in boxes:
         label = int(bbox[4])
         scores = jaccard_with_anchors(bbox)
-        idx = np.argmax(scores)
-        scores[idx] = 2.0
         mask = (scores > matching_threshold)
+        if not np.any(mask):
+            mask[np.argmax(scores)] = True
         mask = mask & (scores > pre_scores)
-        pre_scores = np.maximum(pre_scores, scores)
+        pre_scores = np.maximum(pre_scores, scores * mask)
         t_label = mask * label + (1 - mask) * t_label
         for i in range(4):
             t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i]
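This hunk changes the forced match: instead of inflating the best anchor's score to 2.0 up front, the new code only falls back to the single best anchor when none clears the threshold. A tiny standalone sketch of the rule, with made-up scores:

```python
import numpy as np

matching_threshold = 0.5
pre_scores = np.zeros(4, dtype=np.float32)

scores = np.array([0.1, 0.3, 0.45, 0.2], dtype=np.float32)  # all below threshold
mask = (scores > matching_threshold)
if not np.any(mask):
    mask[np.argmax(scores)] = True       # force-match only the best anchor
mask = mask & (scores > pre_scores)
pre_scores = np.maximum(pre_scores, scores * mask)
print(mask)        # [False False  True False]
print(pre_scores)  # [0.   0.   0.45 0.  ]
```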
@@ -134,13 +140,13 @@ def ssd_bboxes_encode(boxes):
         bboxes_t[:, 2:4] = np.log(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4]) / config.PRIOR_SCALING[1]
         bboxes[index] = bboxes_t
-    num_match_num = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32)
-    return bboxes, t_label.astype(np.int32), num_match_num
+    num_match = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32)
+    return bboxes, t_label.astype(np.int32), num_match
-def ssd_bboxes_decode(boxes, index):
-    """Decode predicted boxes to [x, y, w, h]."""
-    boxes_t = boxes[index]
-    default_boxes_t = default_boxes[index]
+def ssd_bboxes_decode(boxes):
+    """Decode predicted boxes to [y, x, h, w]."""
+    boxes_t = boxes.copy()
+    default_boxes_t = default_boxes.copy()
     boxes_t[:, :2] = boxes_t[:, :2] * config.PRIOR_SCALING[0] * default_boxes_t[:, 2:] + default_boxes_t[:, :2]
     boxes_t[:, 2:4] = np.exp(boxes_t[:, 2:4] * config.PRIOR_SCALING[1]) * default_boxes_t[:, 2:4]
@@ -149,41 +155,101 @@ def ssd_bboxes_decode(boxes, index):
     bboxes[:, [0, 1]] = boxes_t[:, [0, 1]] - boxes_t[:, [2, 3]] / 2
     bboxes[:, [2, 3]] = boxes_t[:, [0, 1]] + boxes_t[:, [2, 3]] / 2
-    return bboxes
+    return np.clip(bboxes, 0, 1)
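Together, encode and decode implement the standard SSD box parameterization: center offsets scaled by `PRIOR_SCALING[0]`, log size ratios scaled by `PRIOR_SCALING[1]`. A standalone round-trip sketch with one hypothetical anchor:

```python
import numpy as np

PRIOR_SCALING = (0.1, 0.2)
anchor = np.array([0.50, 0.50, 0.20, 0.20])  # one hypothetical default box, [cy, cx, h, w]
gt = np.array([0.52, 0.49, 0.25, 0.18])      # a ground-truth box in the same format

# Encode: scaled center offsets, scaled log size ratios.
loc = np.concatenate([(gt[:2] - anchor[:2]) / anchor[2:] / PRIOR_SCALING[0],
                      np.log(gt[2:] / anchor[2:]) / PRIOR_SCALING[1]])

# Decode inverts the transform and recovers the ground truth exactly.
decoded = np.concatenate([loc[:2] * PRIOR_SCALING[0] * anchor[2:] + anchor[:2],
                          np.exp(loc[2:] * PRIOR_SCALING[1]) * anchor[2:]])
assert np.allclose(decoded, gt)
```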
-def preprocess_fn(image, box, is_training):
-    """Preprocess function for dataset."""
+def _rand(a=0., b=1.):
+    """Generate a random number in [a, b)."""
+    return np.random.rand() * (b - a) + a
+def intersect(box_a, box_b):
+    """Compute the intersection area of two sets of boxes."""
+    max_yx = np.minimum(box_a[:, 2:4], box_b[2:4])
+    min_yx = np.maximum(box_a[:, :2], box_b[:2])
+    inter = np.clip((max_yx - min_yx), a_min=0, a_max=np.inf)
+    return inter[:, 0] * inter[:, 1]
-    def _infer_data(image, input_shape, box):
-        img_h, img_w, _ = image.shape
-        input_h, input_w = input_shape
-        scale = min(float(input_w) / float(img_w), float(input_h) / float(img_h))
-        nw = int(img_w * scale)
-        nh = int(img_h * scale)
+def jaccard_numpy(box_a, box_b):
+    """Compute the jaccard overlap of two sets of boxes."""
+    inter = intersect(box_a, box_b)
+    area_a = ((box_a[:, 2] - box_a[:, 0]) *
+              (box_a[:, 3] - box_a[:, 1]))
+    area_b = ((box_b[2] - box_b[0]) *
+              (box_b[3] - box_b[1]))
+    union = area_a + area_b - inter
+    return inter / union
+def random_sample_crop(image, boxes):
+    """Randomly crop the image and boxes."""
+    height, width, _ = image.shape
+    min_iou = np.random.choice([None, 0.1, 0.3, 0.5, 0.7, 0.9])
+    if min_iou is None:
+        return image, boxes
+    # max trials (50)
+    for _ in range(50):
+        image_t = image
+        w = _rand(0.3, 1.0) * width
+        h = _rand(0.3, 1.0) * height
+        # aspect ratio constraint between .5 and 2
+        if h / w < 0.5 or h / w > 2:
+            continue
+        left = _rand() * (width - w)
+        top = _rand() * (height - h)
+        rect = np.array([int(top), int(left), int(top + h), int(left + w)])
+        overlap = jaccard_numpy(boxes, rect)
+        # drop boxes that do not overlap the crop at all
+        drop_mask = overlap > 0
+        if not drop_mask.any():
+            continue
+        if overlap[drop_mask].min() < min_iou:
+            continue
+        image_t = image_t[rect[0]:rect[2], rect[1]:rect[3], :]
+        centers = (boxes[:, :2] + boxes[:, 2:4]) / 2.0
+        m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])
+        m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])
+        # keep boxes whose centers satisfy both m1 and m2
+        mask = m1 * m2 * drop_mask
-        image = cv2.resize(image, (nw, nh))
+        # have any valid boxes? try again if not
+        if not mask.any():
+            continue
-        new_image = np.zeros((input_h, input_w, 3), np.float32)
-        dh = (input_h - nh) // 2
-        dw = (input_w - nw) // 2
-        new_image[dh: (nh + dh), dw: (nw + dw), :] = image
-        image = new_image
+        # take only matching gt boxes
+        boxes_t = boxes[mask, :].copy()
+        boxes_t[:, :2] = np.maximum(boxes_t[:, :2], rect[:2])
+        boxes_t[:, :2] -= rect[:2]
+        boxes_t[:, 2:4] = np.minimum(boxes_t[:, 2:4], rect[2:4])
+        boxes_t[:, 2:4] -= rect[:2]
+        return image_t, boxes_t
+    return image, boxes
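A standalone sketch of how the new crop is used, assuming `random_sample_crop` above is in scope and boxes are in absolute pixel coordinates in the `[ymin, xmin, ymax, xmax, cls]` order adopted by this change (note the helper may also return its inputs unchanged when no valid crop is sampled):

```python
import numpy as np

image = np.zeros((480, 640, 3), dtype=np.uint8)                # H x W x C
boxes = np.array([[100, 150, 300, 400, 1]], dtype=np.float32)  # [ymin, xmin, ymax, xmax, cls]

cropped, kept = random_sample_crop(image, boxes)
# `cropped` is a sub-window of `image`; `kept` holds only the boxes whose
# centers fall inside the crop, translated into crop coordinates.
print(cropped.shape, kept)
```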
+def preprocess_fn(img_id, image, box, is_training):
+    """Preprocess function for dataset."""
+    def _infer_data(image, input_shape):
+        img_h, img_w, _ = image.shape
+        input_h, input_w = input_shape
+        image = cv2.resize(image, (input_w, input_h))
         # When the channels of image is 1
         if len(image.shape) == 2:
             image = np.expand_dims(image, axis=-1)
             image = np.concatenate([image, image, image], axis=-1)
-        box = box.astype(np.float32)
-        box[:, [0, 2]] = (box[:, [0, 2]] * scale + dw) / input_w
-        box[:, [1, 3]] = (box[:, [1, 3]] * scale + dh) / input_h
-        return image, np.array((img_h, img_w), np.float32), box
+        return img_id, image, np.array((img_h, img_w), np.float32)
     def _data_aug(image, box, is_training, image_size=(300, 300)):
         """Data augmentation function."""
@@ -191,53 +257,34 @@ def preprocess_fn(image, box, is_training):
         w, h = image_size
         if not is_training:
-            return _infer_data(image, image_size, box)
-        # Random settings
-        scale_w = _rand(0.75, 1.25)
-        scale_h = _rand(0.75, 1.25)
+            return _infer_data(image, image_size)
-        flip = _rand() < .5
-        nw = iw * scale_w
-        nh = ih * scale_h
-        scale = min(w / nw, h / nh)
-        nw = int(scale * nw)
-        nh = int(scale * nh)
+        # Random crop
+        box = box.astype(np.float32)
+        image, box = random_sample_crop(image, box)
+        ih, iw, _ = image.shape
         # Resize image
-        image = cv2.resize(image, (nw, nh))
-        # place image
-        new_image = np.zeros((h, w, 3), dtype=np.float32)
-        dw = (w - nw) // 2
-        dh = (h - nh) // 2
-        new_image[dh:dh + nh, dw:dw + nw, :] = image
-        image = new_image
+        image = cv2.resize(image, (w, h))
         # Flip image or not
+        flip = _rand() < .5
         if flip:
            image = cv2.flip(image, 1, dst=None)
-        # Convert image to gray or not
-        gray = _rand() < .25
-        if gray:
-            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
         # When the channels of image is 1
         if len(image.shape) == 2:
             image = np.expand_dims(image, axis=-1)
             image = np.concatenate([image, image, image], axis=-1)
-        box = box.astype(np.float32)
-        # Transform box with shape [x1, y1, x2, y2].
-        box[:, [0, 2]] = (box[:, [0, 2]] * scale * scale_w + dw) / w
-        box[:, [1, 3]] = (box[:, [1, 3]] * scale * scale_h + dh) / h
+        box[:, [0, 2]] = box[:, [0, 2]] / ih
+        box[:, [1, 3]] = box[:, [1, 3]] / iw
         if flip:
-            box[:, [0, 2]] = 1 - box[:, [2, 0]]
+            box[:, [1, 3]] = 1 - box[:, [3, 1]]
-        box, label, num_match_num = ssd_bboxes_encode(box)
-        return image, box, label, num_match_num
+        box, label, num_match = ssd_bboxes_encode(box)
+        return image, box, label, num_match
     return _data_aug(image, box, is_training, image_size=config.IMG_SHAPE)
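Note the two output signatures this gives `preprocess_fn`; a sketch of both calls, assuming the module-level helpers above are in scope (shapes are illustrative):

```python
import numpy as np

raw_image = np.zeros((480, 640, 3), dtype=np.uint8)
raw_boxes = np.array([[100, 150, 300, 400, 1]], dtype=np.float32)

# Inference mode keeps the image id and original shape for COCO evaluation.
img_id, image, image_shape = preprocess_fn(42, raw_image, raw_boxes, is_training=False)

# Training mode returns the augmented image plus the encoded targets.
image, box, label, num_match = preprocess_fn(42, raw_image, raw_boxes, is_training=True)
```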
@@ -265,7 +312,8 @@ def create_coco_label(is_training):
         classs_dict[cat["id"]] = cat["name"]
     image_ids = coco.getImgIds()
-    image_files = []
+    images = []
+    image_path_dict = {}
     image_anno_dict = {}
     for img_id in image_ids:
@@ -275,17 +323,23 @@ def create_coco_label(is_training):
         anno = coco.loadAnns(anno_ids)
         image_path = os.path.join(coco_root, data_type, file_name)
         annos = []
+        iscrowd = False
         for label in anno:
             bbox = label["bbox"]
             class_name = classs_dict[label["category_id"]]
+            iscrowd = iscrowd or label["iscrowd"]
             if class_name in train_cls:
                 x_min, x_max = bbox[0], bbox[0] + bbox[2]
                 y_min, y_max = bbox[1], bbox[1] + bbox[3]
-                annos.append(list(map(round, [x_min, y_min, x_max, y_max])) + [train_cls_dict[class_name]])
+                annos.append(list(map(round, [y_min, x_min, y_max, x_max])) + [train_cls_dict[class_name]])
+        if not is_training and iscrowd:
+            continue
         if len(annos) >= 1:
-            image_files.append(image_path)
-            image_anno_dict[image_path] = np.array(annos)
-    return image_files, image_anno_dict
+            images.append(img_id)
+            image_path_dict[img_id] = image_path
+            image_anno_dict[img_id] = np.array(annos)
+    return images, image_path_dict, image_anno_dict
 def anno_parser(annos_str):
@@ -299,7 +353,8 @@ def anno_parser(annos_str):
 def filter_valid_data(image_dir, anno_path):
     """Filter valid image files, which are in both image_dir and anno_path."""
-    image_files = []
+    images = []
+    image_path_dict = {}
     image_anno_dict = {}
     if not os.path.isdir(image_dir):
         raise RuntimeError("Path given is not valid.")
@@ -308,15 +363,17 @@ def filter_valid_data(image_dir, anno_path):
     with open(anno_path, "rb") as f:
         lines = f.readlines()
-    for line in lines:
+    for img_id, line in enumerate(lines):
         line_str = line.decode("utf-8").strip()
         line_split = str(line_str).split(' ')
         file_name = line_split[0]
         image_path = os.path.join(image_dir, file_name)
         if os.path.isfile(image_path):
-            image_anno_dict[image_path] = anno_parser(line_split[1:])
-            image_files.append(image_path)
-    return image_files, image_anno_dict
+            images.append(img_id)
+            image_path_dict[img_id] = image_path
+            image_anno_dict[img_id] = anno_parser(line_split[1:])
+    return images, image_path_dict, image_anno_dict
 def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.mindrecord", file_num=8):
@@ -325,21 +382,24 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.
     mindrecord_path = os.path.join(mindrecord_dir, prefix)
     writer = FileWriter(mindrecord_path, file_num)
     if dataset == "coco":
-        image_files, image_anno_dict = create_coco_label(is_training)
+        images, image_path_dict, image_anno_dict = create_coco_label(is_training)
     else:
-        image_files, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH)
+        images, image_path_dict, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH)
     ssd_json = {
+        "img_id": {"type": "int32", "shape": [1]},
         "image": {"type": "bytes"},
         "annotation": {"type": "int32", "shape": [-1, 5]},
     }
     writer.add_schema(ssd_json, "ssd_json")
-    for image_name in image_files:
-        with open(image_name, 'rb') as f:
+    for img_id in images:
+        image_path = image_path_dict[img_id]
+        with open(image_path, 'rb') as f:
             img = f.read()
-        annos = np.array(image_anno_dict[image_name], dtype=np.int32)
-        row = {"image": img, "annotation": annos}
+        annos = np.array(image_anno_dict[img_id], dtype=np.int32)
+        img_id = np.array([img_id], dtype=np.int32)
+        row = {"img_id": img_id, "image": img, "annotation": annos}
         writer.write_raw_data([row])
     writer.commit()
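Once written, the records can be read back with `MindDataset`; a minimal sketch (the `0` shard suffix and the path are assumptions based on the defaults above, since `file_num=8` produces numbered shards):

```python
import mindspore.dataset as de

ds = de.MindDataset("/data/MindRecord_COCO/ssd.mindrecord0",
                    columns_list=["img_id", "image", "annotation"])
print("records:", ds.get_dataset_size())
```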
@@ -347,29 +407,26 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.
 def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0,
                        is_training=True, num_parallel_workers=4):
     """Create SSD dataset with MindDataset."""
-    ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,
-                        num_parallel_workers=num_parallel_workers, shuffle=is_training)
+    ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num,
+                        shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training)
     decode = C.Decode()
     ds = ds.map(input_columns=["image"], operations=decode)
-    compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
+    change_swap_op = C.HWC2CHW()
+    normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255])
+    color_adjust_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
+    compose_map_func = (lambda img_id, image, annotation: preprocess_fn(img_id, image, annotation, is_training))
     if is_training:
-        hwc_to_chw = C.HWC2CHW()
-        ds = ds.map(input_columns=["image", "annotation"],
-                    output_columns=["image", "box", "label", "num_match_num"],
-                    columns_order=["image", "box", "label", "num_match_num"],
-                    operations=compose_map_func, python_multiprocessing=True, num_parallel_workers=num_parallel_workers)
-        ds = ds.map(input_columns=["image"], operations=hwc_to_chw, python_multiprocessing=True,
-                    num_parallel_workers=num_parallel_workers)
-        ds = ds.batch(batch_size, drop_remainder=True)
-        ds = ds.repeat(repeat_num)
+        output_columns = ["image", "box", "label", "num_match"]
+        trans = [color_adjust_op, normalize_op, change_swap_op]
     else:
-        hwc_to_chw = C.HWC2CHW()
-        ds = ds.map(input_columns=["image", "annotation"],
-                    output_columns=["image", "image_shape", "annotation"],
-                    columns_order=["image", "image_shape", "annotation"],
-                    operations=compose_map_func)
-        ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers)
-        ds = ds.batch(batch_size, drop_remainder=True)
-        ds = ds.repeat(repeat_num)
+        output_columns = ["img_id", "image", "image_shape"]
+        trans = [normalize_op, change_swap_op]
+    ds = ds.map(input_columns=["img_id", "image", "annotation"],
+                output_columns=output_columns, columns_order=output_columns,
+                operations=compose_map_func, python_multiprocessing=is_training,
+                num_parallel_workers=num_parallel_workers)
+    ds = ds.map(input_columns=["image"], operations=trans, python_multiprocessing=is_training,
+                num_parallel_workers=num_parallel_workers)
+    ds = ds.batch(batch_size, drop_remainder=True)
+    ds = ds.repeat(repeat_num)
     return ds
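A typical training-side call of the rebuilt pipeline (a sketch; the MindRecord path is an assumption matching the config defaults):

```python
dataset = create_ssd_dataset("/data/MindRecord_COCO/ssd.mindrecord0",
                             batch_size=32, repeat_num=1,
                             device_num=1, rank=0, is_training=True)
print("batches per epoch:", dataset.get_dataset_size())
# Each batch yields the columns ["image", "box", "label", "num_match"].
```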
@@ -17,6 +17,7 @@
 import os
 import argparse
 import time
+import numpy as np
 from mindspore import context, Tensor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.model_zoo.ssd import SSD300, ssd_mobilenet_v2
@@ -26,8 +27,8 @@ from util import metrics
 def ssd_eval(dataset_path, ckpt_path):
     """SSD evaluation."""
-    ds = create_ssd_dataset(dataset_path, batch_size=1, repeat_num=1, is_training=False)
+    batch_size = 32
+    ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1, is_training=False)
     net = SSD300(ssd_mobilenet_v2(), ConfigSSD(), is_training=False)
     print("Load Checkpoint!")
     param_dict = load_checkpoint(ckpt_path)
@@ -35,28 +36,28 @@ def ssd_eval(dataset_path, ckpt_path):
     load_param_into_net(net, param_dict)
     net.set_train(False)
-    i = 1.
-    total = ds.get_dataset_size()
+    i = batch_size
+    total = ds.get_dataset_size() * batch_size
     start = time.time()
     pred_data = []
     print("\n========================================\n")
     print("total images num: ", total)
     print("Processing, please wait a moment.")
     for data in ds.create_dict_iterator():
+        img_id = data['img_id']
         img_np = data['image']
         image_shape = data['image_shape']
-        annotation = data['annotation']
         output = net(Tensor(img_np))
         for batch_idx in range(img_np.shape[0]):
             pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
                               "box_scores": output[1].asnumpy()[batch_idx],
-                              "annotation": annotation,
-                              "image_shape": image_shape})
-        percent = round(i / total * 100, 2)
+                              "img_id": int(np.squeeze(img_id[batch_idx])),
+                              "image_shape": image_shape[batch_idx]})
+        percent = round(i / total * 100., 2)
         print(f' {str(percent)} [{i}/{total}]', end='\r')
-        i += 1
+        i += batch_size
     cost_time = int((time.time() - start) * 1000)
     print(f' 100% [{total}/{total}] cost {cost_time} ms')
     mAP = metrics(pred_data)
@@ -16,11 +16,17 @@
 echo "=============================================================================================================="
 echo "Please run the script as: "
-echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE MINDSPORE_HCCL_CONFIG_PATH"
-echo "for example: sh run_distribute_train.sh 8 150 coco /data/hccl.json"
+echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET MINDSPORE_HCCL_CONFIG_PATH PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
+echo "for example: sh run_distribute_train.sh 8 500 0.2 coco /data/hccl.json /opt/ssd-300.ckpt(optional) 200(optional)"
 echo "It is better to use absolute paths."
-echo "The learning rate is 0.4 as default, if you want other lr, please change the value in this script."
-echo "=============================================================================================================="
+echo "================================================================================================================="
+if [ $# != 5 ] && [ $# != 7 ]
+then
+    echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] \
+[MINDSPORE_HCCL_CONFIG_PATH] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
+    exit 1
+fi
 # Before starting distributed training, first create the mindrecord files.
 python train.py --only_create_dataset=1
@@ -29,9 +35,11 @@ echo "After running the script, the network runs in the background. The log will
 export RANK_SIZE=$1
 EPOCH_SIZE=$2
-DATASET=$3
-export MINDSPORE_HCCL_CONFIG_PATH=$4
+LR=$3
+DATASET=$4
+PRE_TRAINED=$6
+PRE_TRAINED_EPOCH_SIZE=$7
+export MINDSPORE_HCCL_CONFIG_PATH=$5
 for((i=0;i<RANK_SIZE;i++))
 do
@@ -43,12 +51,29 @@ do
     export RANK_ID=$i
    echo "start training for rank $i, device $DEVICE_ID"
     env > env.log
-    python ../train.py \
-    --distribute=1 \
-    --lr=0.4 \
-    --dataset=$DATASET \
-    --device_num=$RANK_SIZE \
-    --device_id=$DEVICE_ID \
-    --epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
+    if [ $# == 5 ]
+    then
+        python ../train.py \
+        --distribute=1 \
+        --lr=$LR \
+        --dataset=$DATASET \
+        --device_num=$RANK_SIZE \
+        --device_id=$DEVICE_ID \
+        --epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
+    fi
+    if [ $# == 7 ]
+    then
+        python ../train.py \
+        --distribute=1 \
+        --lr=$LR \
+        --dataset=$DATASET \
+        --device_num=$RANK_SIZE \
+        --device_id=$DEVICE_ID \
+        --pre_trained=$PRE_TRAINED \
+        --pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \
+        --epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
+    fi
     cd ../
 done
@@ -16,79 +16,34 @@
 """train SSD and get checkpoint files."""
 import os
-import math
 import argparse
-import numpy as np
 import mindspore.nn as nn
 from mindspore import context, Tensor
 from mindspore.communication.management import init
 from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
 from mindspore.train import Model, ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.common.initializer import initializer
 from mindspore.model_zoo.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2
 from config import ConfigSSD
 from dataset import create_ssd_dataset, data_to_mindrecord_byte_image
+from util import get_lr, init_net_param
-def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
-    """
-    generate learning rate array
-    Args:
-        global_step(int): total steps of the training
-        lr_init(float): init learning rate
-        lr_end(float): end learning rate
-        lr_max(float): max learning rate
-        warmup_epochs(int): number of warmup epochs
-        total_epochs(int): total epoch of training
-        steps_per_epoch(int): steps of one epoch
-    Returns:
-        np.array, learning rate array
-    """
-    lr_each_step = []
-    total_steps = steps_per_epoch * total_epochs
-    warmup_steps = steps_per_epoch * warmup_epochs
-    for i in range(total_steps):
-        if i < warmup_steps:
-            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
-        else:
-            lr = lr_end + (lr_max - lr_end) * \
-                (1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
-        if lr < 0.0:
-            lr = 0.0
-        lr_each_step.append(lr)
-    current_step = global_step
-    lr_each_step = np.array(lr_each_step).astype(np.float32)
-    learning_rate = lr_each_step[current_step:]
-    return learning_rate
-def init_net_param(network, initialize_mode='XavierUniform'):
-    """Init the parameters in net."""
-    params = network.trainable_params()
-    for p in params:
-        if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
-            p.set_parameter_data(initializer(initialize_mode, p.data.shape(), p.data.dtype()))
 def main():
     parser = argparse.ArgumentParser(description="SSD training")
     parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create "
-                                                                                "Mindrecord, default is false.")
-    parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
+                                                                                "Mindrecord, default is False.")
+    parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is False.")
     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
     parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
-    parser.add_argument("--lr", type=float, default=0.25, help="Learning rate, default is 0.25.")
+    parser.add_argument("--lr", type=float, default=0.1, help="Learning rate, default is 0.1.")
     parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.")
     parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
-    parser.add_argument("--epoch_size", type=int, default=70, help="Epoch size, default is 70.")
+    parser.add_argument("--epoch_size", type=int, default=250, help="Epoch size, default is 250.")
     parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
     parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.")
-    parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.")
+    parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.")
+    parser.add_argument("--save_checkpoint_epochs", type=int, default=10, help="Save checkpoint epochs, default is 10.")
     parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
     args_opt = parser.parse_args()
@@ -142,7 +97,8 @@ def main():
     dataset_size = dataset.get_dataset_size()
     print("Create dataset done!")
-    ssd = SSD300(backbone=ssd_mobilenet_v2(), config=config)
+    backbone = ssd_mobilenet_v2()
+    ssd = SSD300(backbone=backbone, config=config)
     net = SSDWithLossCell(ssd, config)
     init_net_param(net)
@@ -150,17 +106,19 @@ def main():
     ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
     ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=None, config=ckpt_config)
-    lr = Tensor(get_lr(global_step=0, lr_init=0, lr_end=0, lr_max=args_opt.lr,
-                       warmup_epochs=max(args_opt.epoch_size // 20, 1),
-                       total_epochs=args_opt.epoch_size,
-                       steps_per_epoch=dataset_size))
-    opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 0.0001, loss_scale)
-    net = TrainingWrapper(net, opt, loss_scale)
     if args_opt.pre_trained:
+        if args_opt.pre_trained_epoch_size <= 0:
+            raise KeyError("pre_trained_epoch_size must be greater than 0.")
         param_dict = load_checkpoint(args_opt.pre_trained)
         load_param_into_net(net, param_dict)
+    lr = Tensor(get_lr(global_step=0, lr_init=0.001, lr_end=0.001 * args_opt.lr, lr_max=args_opt.lr,
+                       warmup_epochs=2,
+                       total_epochs=args_opt.epoch_size,
+                       steps_per_epoch=dataset_size))
+    opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 1e-4, loss_scale)
+    net = TrainingWrapper(net, opt, loss_scale)
     callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
     model = Model(net)
@@ -14,43 +14,83 @@
 # ============================================================================
 """metrics utils"""
+import os
+import json
+import math
 import numpy as np
+from mindspore import Tensor
+from mindspore.common.initializer import initializer, TruncatedNormal
 from config import ConfigSSD
 from dataset import ssd_bboxes_decode
-def calc_iou(bbox_pred, bbox_ground):
-    """Calculate iou of predicted bbox and ground truth."""
-    bbox_pred = np.expand_dims(bbox_pred, axis=0)
-    pred_w = bbox_pred[:, 2] - bbox_pred[:, 0]
-    pred_h = bbox_pred[:, 3] - bbox_pred[:, 1]
-    pred_area = pred_w * pred_h
-    gt_w = bbox_ground[:, 2] - bbox_ground[:, 0]
-    gt_h = bbox_ground[:, 3] - bbox_ground[:, 1]
-    gt_area = gt_w * gt_h
-    iw = np.minimum(bbox_pred[:, 2], bbox_ground[:, 2]) - np.maximum(bbox_pred[:, 0], bbox_ground[:, 0])
-    ih = np.minimum(bbox_pred[:, 3], bbox_ground[:, 3]) - np.maximum(bbox_pred[:, 1], bbox_ground[:, 1])
-    iw = np.maximum(iw, 0)
-    ih = np.maximum(ih, 0)
-    intersection_area = iw * ih
+def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
+    """
+    generate learning rate array
+    Args:
+        global_step(int): total steps of the training
+        lr_init(float): init learning rate
+        lr_end(float): end learning rate
+        lr_max(float): max learning rate
+        warmup_epochs(int): number of warmup epochs
+        total_epochs(int): total epoch of training
+        steps_per_epoch(int): steps of one epoch
+    Returns:
+        np.array, learning rate array
+    """
+    lr_each_step = []
+    total_steps = steps_per_epoch * total_epochs
+    warmup_steps = steps_per_epoch * warmup_epochs
+    for i in range(total_steps):
+        if i < warmup_steps:
+            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
+        else:
+            lr = lr_end + \
+                 (lr_max - lr_end) * \
+                 (1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
+        if lr < 0.0:
+            lr = 0.0
+        lr_each_step.append(lr)
+    current_step = global_step
+    lr_each_step = np.array(lr_each_step).astype(np.float32)
+    learning_rate = lr_each_step[current_step:]
+    return learning_rate
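For intuition, a quick standalone check of the schedule the relocated `get_lr` produces (toy sizes; the printed values follow directly from the warmup and cosine formulas above):

```python
lr = get_lr(global_step=0, lr_init=0.001, lr_end=0.0002, lr_max=0.2,
            warmup_epochs=2, total_epochs=10, steps_per_epoch=5)
print(len(lr))  # 50 steps in total
print(lr[0])    # 0.001: warmup starts at lr_init
print(lr[10])   # 0.2:   the first post-warmup step sits at lr_max (cos(0) term)
print(lr[-1])   # ~0.0005: decays toward lr_end by the last step
```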
+def init_net_param(network, initialize_mode='TruncatedNormal'):
+    """Init the parameters in net."""
+    params = network.trainable_params()
+    for p in params:
+        if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
+            if initialize_mode == 'TruncatedNormal':
+                p.set_parameter_data(initializer(TruncatedNormal(0.03), p.data.shape(), p.data.dtype()))
+            else:
+                p.set_parameter_data(initializer(initialize_mode, p.data.shape(), p.data.dtype()))
-    union_area = pred_area + gt_area - intersection_area
-    union_area = np.maximum(union_area, np.finfo(float).eps)
-    iou = intersection_area * 1. / union_area
-    return iou
+def load_backbone_params(network, param_dict):
+    """Init the parameters from a pre-trained model, default is mobilenetv2."""
+    for _, param in network.parameters_and_names():
+        param_name = param.name.replace('network.backbone.', '')
+        name_split = param_name.split('.')
+        if 'features_1' in param_name:
+            param_name = param_name.replace('features_1', 'features')
+        if 'features_2' in param_name:
+            param_name = '.'.join(['features', str(int(name_split[1]) + 14)] + name_split[2:])
+        if param_name in param_dict:
+            param.set_parameter_data(param_dict[param_name].data)
 def apply_nms(all_boxes, all_scores, thres, max_boxes):
     """Apply NMS to bboxes."""
-    x1 = all_boxes[:, 0]
-    y1 = all_boxes[:, 1]
-    x2 = all_boxes[:, 2]
-    y2 = all_boxes[:, 3]
+    y1 = all_boxes[:, 0]
+    x1 = all_boxes[:, 1]
+    y2 = all_boxes[:, 2]
+    x2 = all_boxes[:, 3]
     areas = (x2 - x1 + 1) * (y2 - y1 + 1)
     order = all_scores.argsort()[::-1]
@@ -80,127 +120,73 @@ def apply_nms(all_boxes, all_scores, thres, max_boxes):
     return keep
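The loop between these two hunks is elided by the diff; for reference, a greedy NMS of this shape (a sketch built on the names defined above, not necessarily the file's exact code) looks like:

```python
keep = []
while order.size > 0:
    i = order[0]
    keep.append(i)          # accept the highest-scoring remaining box
    if len(keep) >= max_boxes:
        break
    # IoU of the accepted box against all remaining boxes
    yy1 = np.maximum(y1[i], y1[order[1:]])
    xx1 = np.maximum(x1[i], x1[order[1:]])
    yy2 = np.minimum(y2[i], y2[order[1:]])
    xx2 = np.minimum(x2[i], x2[order[1:]])
    w = np.maximum(0.0, xx2 - xx1 + 1)
    h = np.maximum(0.0, yy2 - yy1 + 1)
    inter = w * h
    ovr = inter / (areas[i] + areas[order[1:]] - inter)
    # keep only boxes overlapping less than the threshold
    inds = np.where(ovr <= thres)[0]
    order = order[inds + 1]
```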
-def calc_ap(recall, precision):
-    """Calculate AP."""
-    correct_recall = np.concatenate(([0.], recall, [1.]))
-    correct_precision = np.concatenate(([0.], precision, [0.]))
-    for i in range(correct_recall.size - 1, 0, -1):
-        correct_precision[i - 1] = np.maximum(correct_precision[i - 1], correct_precision[i])
-    i = np.where(correct_recall[1:] != correct_recall[:-1])[0]
-    ap = np.sum((correct_recall[i + 1] - correct_recall[i]) * correct_precision[i + 1])
-    return ap
 def metrics(pred_data):
     """Calculate mAP of predicted bboxes."""
+    from pycocotools.coco import COCO
+    from pycocotools.cocoeval import COCOeval
     config = ConfigSSD()
     num_classes = config.NUM_CLASSES
-    all_detections = [None for i in range(num_classes)]
-    all_pred_scores = [None for i in range(num_classes)]
-    all_annotations = [None for i in range(num_classes)]
-    average_precisions = {}
-    num = [0 for i in range(num_classes)]
-    accurate_num = [0 for i in range(num_classes)]
+    coco_root = config.COCO_ROOT
+    data_type = config.VAL_DATA_TYPE
-    for sample in pred_data:
-        pred_boxes = sample['boxes']
-        boxes_scores = sample['box_scores']
-        annotation = sample['annotation']
+    # Classes need to train or test.
+    val_cls = config.COCO_CLASSES
+    val_cls_dict = {}
+    for i, cls in enumerate(val_cls):
+        val_cls_dict[i] = cls
-        annotation = np.squeeze(annotation, axis=0)
+    anno_json = os.path.join(coco_root, config.INSTANCES_SET.format(data_type))
+    coco_gt = COCO(anno_json)
+    classs_dict = {}
+    cat_ids = coco_gt.loadCats(coco_gt.getCatIds())
+    for cat in cat_ids:
+        classs_dict[cat["name"]] = cat["id"]
-        pred_labels = np.argmax(boxes_scores, axis=-1)
-        index = np.nonzero(pred_labels)
-        pred_boxes = ssd_bboxes_decode(pred_boxes, index)
+    predictions = []
+    img_ids = []
-        pred_boxes = pred_boxes.clip(0, 1)
-        boxes_scores = np.max(boxes_scores, axis=-1)
-        boxes_scores = boxes_scores[index]
-        pred_labels = pred_labels[index]
+    for sample in pred_data:
+        pred_boxes = sample['boxes']
+        box_scores = sample['box_scores']
+        img_id = sample['img_id']
+        h, w = sample['image_shape']
-        top_k = 50
+        pred_boxes = ssd_bboxes_decode(pred_boxes)
+        final_boxes = []
+        final_label = []
+        final_score = []
+        img_ids.append(img_id)
         for c in range(1, num_classes):
-            if len(pred_labels) >= 1:
-                class_box_scores = boxes_scores[pred_labels == c]
-                class_boxes = pred_boxes[pred_labels == c]
-                nms_index = apply_nms(class_boxes, class_box_scores, config.MATCH_THRESHOLD, top_k)
+            class_box_scores = box_scores[:, c]
+            score_mask = class_box_scores > config.MIN_SCORE
+            class_box_scores = class_box_scores[score_mask]
+            class_boxes = pred_boxes[score_mask] * [h, w, h, w]
+            if score_mask.any():
+                nms_index = apply_nms(class_boxes, class_box_scores, config.NMS_THRESHOLD, config.TOP_K)
                 class_boxes = class_boxes[nms_index]
                 class_box_scores = class_box_scores[nms_index]
-                cmask = class_box_scores > 0.5
-                class_boxes = class_boxes[cmask]
-                class_box_scores = class_box_scores[cmask]
-                all_detections[c] = class_boxes
-                all_pred_scores[c] = class_box_scores
-        for c in range(1, num_classes):
-            if len(annotation) >= 1:
-                all_annotations[c] = annotation[annotation[:, 4] == c, :4]
-        for c in range(1, num_classes):
-            false_positives = np.zeros((0,))
-            true_positives = np.zeros((0,))
-            scores = np.zeros((0,))
-            num_annotations = 0.0
-            annotations = all_annotations[c]
-            num_annotations += annotations.shape[0]
-            detections = all_detections[c]
-            pred_scores = all_pred_scores[c]
-            for index, detection in enumerate(detections):
-                scores = np.append(scores, pred_scores[index])
-                if len(annotations) >= 1:
-                    IoUs = calc_iou(detection, annotations)
-                    assigned_anno = np.argmax(IoUs)
-                    max_overlap = IoUs[assigned_anno]
-                    if max_overlap >= 0.5:
-                        false_positives = np.append(false_positives, 0)
-                        true_positives = np.append(true_positives, 1)
-                    else:
-                        false_positives = np.append(false_positives, 1)
-                        true_positives = np.append(true_positives, 0)
-                else:
-                    false_positives = np.append(false_positives, 1)
-                    true_positives = np.append(true_positives, 0)
-            if num_annotations == 0:
-                if c not in average_precisions.keys():
-                    average_precisions[c] = 0
-                continue
-            accurate_num[c] = 1
-            indices = np.argsort(-scores)
-            false_positives = false_positives[indices]
-            true_positives = true_positives[indices]
-            false_positives = np.cumsum(false_positives)
-            true_positives = np.cumsum(true_positives)
-            recall = true_positives * 1. / num_annotations
-            precision = true_positives * 1. / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)
-            average_precision = calc_ap(recall, precision)
-            if c not in average_precisions.keys():
-                average_precisions[c] = average_precision
-            else:
-                average_precisions[c] += average_precision
-            num[c] += 1
-    count = 0
-    for key in average_precisions:
-        if num[key] != 0:
-            count += (average_precisions[key] / num[key])
-    mAP = count * 1. / accurate_num.count(1)
-    return mAP
| final_boxes += class_boxes.tolist() | |||||
| final_score += class_box_scores.tolist() | |||||
| final_label += [class_dict[val_cls_dict[c]]] * len(class_box_scores) | |||||
| for loc, label, score in zip(final_boxes, final_label, final_score): | |||||
| res = {} | |||||
| res['image_id'] = img_id | |||||
| res['bbox'] = [loc[1], loc[0], loc[3] - loc[1], loc[2] - loc[0]] | |||||
| res['score'] = score | |||||
| res['category_id'] = label | |||||
| predictions.append(res) | |||||
| with open('predictions.json', 'w') as f: | |||||
| json.dump(predictions, f) | |||||
| coco_dt = coco_gt.loadRes('predictions.json') | |||||
| E = COCOeval(coco_gt, coco_dt, iouType='bbox') | |||||
| E.params.imgIds = img_ids | |||||
| E.evaluate() | |||||
| E.accumulate() | |||||
| E.summarize() | |||||
| return E.stats[0] | |||||
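
The evaluation path above hands metric computation to pycocotools' `COCOeval` instead of the hand-rolled per-class AP loop it replaces. A minimal sketch of that flow, assuming pycocotools is installed; the file names here are placeholders:

```python
# Sketch of the COCOeval flow used above; paths are placeholders, and
# predictions.json must hold detections in COCO format:
# [{"image_id": ..., "bbox": [x, y, w, h], "score": ..., "category_id": ...}, ...]
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

coco_gt = COCO("instances_val2017.json")       # ground-truth annotations
coco_dt = coco_gt.loadRes("predictions.json")  # detection results
evaluator = COCOeval(coco_gt, coco_dt, iouType="bbox")
evaluator.evaluate()
evaluator.accumulate()
evaluator.summarize()
print("mAP@[.5:.95]:", evaluator.stats[0])     # stats[0] is the headline COCO mAP
```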
| @@ -17,22 +17,13 @@ | |||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| import mindspore as ms | import mindspore as ms | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import context | |||||
| from mindspore import Parameter, context, Tensor | |||||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | from mindspore.parallel._auto_parallel_context import auto_parallel_context | ||||
| from mindspore.communication.management import get_group_size | from mindspore.communication.management import get_group_size | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import functional as F | from mindspore.ops import functional as F | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| from mindspore.ops.operations import TensorAdd | |||||
| from mindspore import Parameter | |||||
| def _conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same'): | |||||
| weight_shape = (out_channel, in_channel, kernel_size, kernel_size) | |||||
| weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor() | |||||
| return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, | |||||
| padding=0, pad_mode=pad_mod, weight_init=weight) | |||||
| def _make_divisible(v, divisor, min_value=None): | def _make_divisible(v, divisor, min_value=None): | ||||
| @@ -46,6 +37,55 @@ def _make_divisible(v, divisor, min_value=None): | |||||
| return new_v | return new_v | ||||
| def _conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same'): | |||||
| return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, | |||||
| padding=0, pad_mode=pad_mod, has_bias=True) | |||||
| def _bn(channel): | |||||
| return nn.BatchNorm2d(channel, eps=1e-3, momentum=0.97, | |||||
| gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1) | |||||
| def _last_conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same', pad=0): | |||||
| depthwise_conv = DepthwiseConv(in_channel, kernel_size, stride, pad_mode='same', pad=pad) | |||||
| conv = _conv2d(in_channel, out_channel, kernel_size=1) | |||||
| return nn.SequentialCell([depthwise_conv, _bn(in_channel), nn.ReLU6(), conv]) | |||||
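
`_last_conv2d` builds each detection-head convolution as a depthwise 3x3 followed by a pointwise 1x1 (with BatchNorm and ReLU6 in between), trading a full 3x3 convolution for far fewer weights. A back-of-the-envelope comparison; the channel sizes below are illustrative only:

```python
# Parameter-count sketch for one head conv; channel sizes are illustrative.
in_ch, out_ch, k = 576, 24, 3
full_conv = in_ch * out_ch * k * k            # standard 3x3 convolution
separable = in_ch * k * k + in_ch * out_ch    # depthwise 3x3 + pointwise 1x1
print(full_conv, separable)                   # 124416 vs 19008, ~6.5x fewer weights
```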
| class ConvBNReLU(nn.Cell): | |||||
| """ | |||||
| Convolution/Depthwise fused with Batchnorm and ReLU block definition. | |||||
| Args: | |||||
| in_planes (int): Input channel. | |||||
| out_planes (int): Output channel. | |||||
| kernel_size (int): Input kernel size. | |||||
| stride (int): Stride size for the first convolutional layer. Default: 1. | |||||
| groups (int): Channel group. Use 1 for standard convolution and the input channel count for depthwise convolution. Default: 1. | |||||
| Returns: | |||||
| Tensor, output tensor. | |||||
| Examples: | |||||
| >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) | |||||
| """ | |||||
| def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): | |||||
| super(ConvBNReLU, self).__init__() | |||||
| padding = 0 | |||||
| if groups == 1: | |||||
| conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same', | |||||
| padding=padding) | |||||
| else: | |||||
| conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='same', pad=padding) | |||||
| layers = [conv, _bn(out_planes), nn.ReLU6()] | |||||
| self.features = nn.SequentialCell(layers) | |||||
| def construct(self, x): | |||||
| output = self.features(x) | |||||
| return output | |||||
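
Compared with the version removed further down, this `ConvBNReLU` switches from `pad_mode='pad'` with `padding=(kernel_size - 1) // 2` to `pad_mode='same'` with `padding=0`, and reuses the shared `_bn` initializer. For odd kernels the two are equivalent at stride 1; `'same'` simply lets the framework derive the padding that yields `out = ceil(in / stride)`. A quick sanity check of that arithmetic, under that assumption:

```python
# 'same' padding output size is ceil(in / stride); at stride 1 the feature
# map keeps its size, matching padding=(k - 1) // 2 in 'pad' mode.
import math
for in_size, stride in [(300, 1), (19, 2)]:
    print(in_size, stride, math.ceil(in_size / stride))  # 300->300, 19->10
```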
| class DepthwiseConv(nn.Cell): | class DepthwiseConv(nn.Cell): | ||||
| """ | """ | ||||
| Depthwise Convolution wrapper definition. | Depthwise Convolution wrapper definition. | ||||
| @@ -64,6 +104,7 @@ class DepthwiseConv(nn.Cell): | |||||
| Examples: | Examples: | ||||
| >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1) | >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1) | ||||
| """ | """ | ||||
| def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): | def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): | ||||
| super(DepthwiseConv, self).__init__() | super(DepthwiseConv, self).__init__() | ||||
| self.has_bias = has_bias | self.has_bias = has_bias | ||||
| @@ -91,42 +132,9 @@ class DepthwiseConv(nn.Cell): | |||||
| return output | return output | ||||
| class ConvBNReLU(nn.Cell): | |||||
| """ | |||||
| Convolution/Depthwise fused with Batchnorm and ReLU block definition. | |||||
| Args: | |||||
| in_planes (int): Input channel. | |||||
| out_planes (int): Output channel. | |||||
| kernel_size (int): Input kernel size. | |||||
| stride (int): Stride size for the first convolutional layer. Default: 1. | |||||
| groups (int): Channel group. Use 1 for standard convolution and the input channel count for depthwise convolution. Default: 1. | |||||
| Returns: | |||||
| Tensor, output tensor. | |||||
| Examples: | |||||
| >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) | |||||
| """ | |||||
| def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): | |||||
| super(ConvBNReLU, self).__init__() | |||||
| padding = (kernel_size - 1) // 2 | |||||
| if groups == 1: | |||||
| conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', | |||||
| padding=padding) | |||||
| else: | |||||
| conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding) | |||||
| layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()] | |||||
| self.features = nn.SequentialCell(layers) | |||||
| def construct(self, x): | |||||
| output = self.features(x) | |||||
| return output | |||||
| class InvertedResidual(nn.Cell): | class InvertedResidual(nn.Cell): | ||||
| """ | """ | ||||
| Mobilenetv2 residual block definition. | |||||
| Residual block definition. | |||||
| Args: | Args: | ||||
| inp (int): Input channel. | inp (int): Input channel. | ||||
| @@ -140,7 +148,7 @@ class InvertedResidual(nn.Cell): | |||||
| Examples: | Examples: | ||||
| >>> InvertedResidual(3, 256, 1, 1) | >>> InvertedResidual(3, 256, 1, 1) | ||||
| """ | """ | ||||
| def __init__(self, inp, oup, stride, expand_ratio): | |||||
| def __init__(self, inp, oup, stride, expand_ratio, last_relu=False): | |||||
| super(InvertedResidual, self).__init__() | super(InvertedResidual, self).__init__() | ||||
| assert stride in [1, 2] | assert stride in [1, 2] | ||||
| @@ -155,17 +163,21 @@ class InvertedResidual(nn.Cell): | |||||
| ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), | ||||
| # pw-linear | # pw-linear | ||||
| nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False), | nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False), | ||||
| nn.BatchNorm2d(oup), | |||||
| _bn(oup), | |||||
| ]) | ]) | ||||
| self.conv = nn.SequentialCell(layers) | self.conv = nn.SequentialCell(layers) | ||||
| self.add = TensorAdd() | |||||
| self.add = P.TensorAdd() | |||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| self.last_relu = last_relu | |||||
| self.relu = nn.ReLU6() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| identity = x | identity = x | ||||
| x = self.conv(x) | x = self.conv(x) | ||||
| if self.use_res_connect: | if self.use_res_connect: | ||||
| return self.add(identity, x) | |||||
| x = self.add(identity, x) | |||||
| if self.last_relu: | |||||
| x = self.relu(x) | |||||
| return x | return x | ||||
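
Indentation is flattened in this diff view, so for clarity: the new `last_relu` flag appends a ReLU6 after the block's output, which is how the SSD extras (built below with `last_relu=True`) get a nonlinearity on their final pointwise projection even when no skip connection applies. A NumPy sketch of the tail of `construct`, under that reading, using a hypothetical helper name:

```python
# Sketch of the InvertedResidual tail: optional skip-add, then optional ReLU6.
import numpy as np

def inverted_residual_tail(identity, x, use_res_connect, last_relu):
    if use_res_connect:           # only when stride == 1 and inp == oup
        x = identity + x
    if last_relu:                 # new in this diff
        x = np.clip(x, 0.0, 6.0)  # ReLU6
    return x
```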
| @@ -214,10 +226,10 @@ class MultiBox(nn.Cell): | |||||
| loc_layers = [] | loc_layers = [] | ||||
| cls_layers = [] | cls_layers = [] | ||||
| for k, out_channel in enumerate(out_channels): | for k, out_channel in enumerate(out_channels): | ||||
| loc_layers += [_conv2d(out_channel, 4 * num_default[k], | |||||
| kernel_size=3, stride=1, pad_mod='same')] | |||||
| cls_layers += [_conv2d(out_channel, num_classes * num_default[k], | |||||
| kernel_size=3, stride=1, pad_mod='same')] | |||||
| loc_layers += [_last_conv2d(out_channel, 4 * num_default[k], | |||||
| kernel_size=3, stride=1, pad_mod='same', pad=0)] | |||||
| cls_layers += [_last_conv2d(out_channel, num_classes * num_default[k], | |||||
| kernel_size=3, stride=1, pad_mod='same', pad=0)] | |||||
| self.multi_loc_layers = nn.layer.CellList(loc_layers) | self.multi_loc_layers = nn.layer.CellList(loc_layers) | ||||
| self.multi_cls_layers = nn.layer.CellList(cls_layers) | self.multi_cls_layers = nn.layer.CellList(cls_layers) | ||||
| @@ -258,13 +270,14 @@ class SSD300(nn.Cell): | |||||
| strides = config.EXTRAS_STRIDES | strides = config.EXTRAS_STRIDES | ||||
| residual_list = [] | residual_list = [] | ||||
| for i in range(2, len(in_channels)): | for i in range(2, len(in_channels)): | ||||
| residual = InvertedResidual(in_channels[i], out_channels[i], stride=strides[i], expand_ratio=ratios[i]) | |||||
| residual = InvertedResidual(in_channels[i], out_channels[i], stride=strides[i], | |||||
| expand_ratio=ratios[i], last_relu=True) | |||||
| residual_list.append(residual) | residual_list.append(residual) | ||||
| self.multi_residual = nn.layer.CellList(residual_list) | self.multi_residual = nn.layer.CellList(residual_list) | ||||
| self.multi_box = MultiBox(config) | self.multi_box = MultiBox(config) | ||||
| self.is_training = is_training | self.is_training = is_training | ||||
| if not is_training: | if not is_training: | ||||
| self.softmax = P.Softmax() | |||||
| self.activation = P.Sigmoid() | |||||
| def construct(self, x): | def construct(self, x): | ||||
| layer_out_13, output = self.backbone(x) | layer_out_13, output = self.backbone(x) | ||||
| @@ -275,77 +288,42 @@ class SSD300(nn.Cell): | |||||
| multi_feature += (feature,) | multi_feature += (feature,) | ||||
| pred_loc, pred_label = self.multi_box(multi_feature) | pred_loc, pred_label = self.multi_box(multi_feature) | ||||
| if not self.is_training: | if not self.is_training: | ||||
| pred_label = self.softmax(pred_label) | |||||
| pred_label = self.activation(pred_label) | |||||
| return pred_loc, pred_label | return pred_loc, pred_label | ||||
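
With the classification loss moving from softmax cross-entropy to a per-class sigmoid focal loss (below), inference scoring switches from `P.Softmax` to `P.Sigmoid` to match: each class is now scored independently rather than normalized across classes. A one-line illustration with made-up logits:

```python
# Per-class sigmoid scores: independent confidences in [0, 1], with no
# normalization across classes (unlike softmax).
import numpy as np
logits = np.array([1.2, -0.3, 0.5])
print(1.0 / (1.0 + np.exp(-logits)))  # approx. [0.77, 0.43, 0.62]
```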
| class LocalizationLoss(nn.Cell): | |||||
| class SigmoidFocalClassificationLoss(nn.Cell): | |||||
| """" | """" | ||||
| Computes the localization loss with SmoothL1Loss. | |||||
| Returns: | |||||
| Tensor, box regression loss. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LocalizationLoss, self).__init__() | |||||
| self.reduce_sum = P.ReduceSum() | |||||
| self.reduce_mean = P.ReduceMean() | |||||
| self.loss = nn.SmoothL1Loss() | |||||
| self.expand_dims = P.ExpandDims() | |||||
| self.less = P.Less() | |||||
| def construct(self, pred_loc, gt_loc, gt_label, num_matched_boxes): | |||||
| mask = F.cast(self.less(0, gt_label), mstype.float32) | |||||
| mask = self.expand_dims(mask, -1) | |||||
| smooth_l1 = self.loss(gt_loc, pred_loc) * mask | |||||
| box_loss = self.reduce_sum(smooth_l1, 1) | |||||
| return self.reduce_mean(box_loss / F.cast(num_matched_boxes, mstype.float32), (0, 1)) | |||||
| class ClassificationLoss(nn.Cell): | |||||
| """" | |||||
| Computes the classification loss with hard example mining. | |||||
| Sigmoid focal-loss for classification. | |||||
| Args: | Args: | ||||
| config (Class): The default config of SSD. | |||||
| gamma (float): Hyper-parameter to balance the easy and hard examples. Default: 2.0 | |||||
| alpha (float): Hyper-parameter to balance the positive and negative examples. Default: 0.75 | |||||
| Returns: | Returns: | ||||
| Tensor, classification loss. | |||||
| Tensor, the focal loss. | |||||
| """ | """ | ||||
| def __init__(self, config): | |||||
| super(ClassificationLoss, self).__init__() | |||||
| self.num_classes = config.NUM_CLASSES | |||||
| self.num_boxes = config.NUM_SSD_BOXES | |||||
| self.neg_pre_positive = config.NEG_PRE_POSITIVE | |||||
| self.minimum = P.Minimum() | |||||
| self.less = P.Less() | |||||
| self.sort = P.TopK() | |||||
| self.tile = P.Tile() | |||||
| self.reduce_sum = P.ReduceSum() | |||||
| self.reduce_mean = P.ReduceMean() | |||||
| self.expand_dims = P.ExpandDims() | |||||
| self.sort_descend = P.TopK(True) | |||||
| self.cross_entropy = nn.SoftmaxCrossEntropyWithLogits(sparse=True) | |||||
| def construct(self, pred_label, gt_label, num_matched_boxes): | |||||
| gt_label = F.cast(gt_label, mstype.int32) | |||||
| mask = F.cast(self.less(0, gt_label), mstype.float32) | |||||
| gt_label_shape = F.shape(gt_label) | |||||
| pred_label = F.reshape(pred_label, (-1, self.num_classes)) | |||||
| gt_label = F.reshape(gt_label, (-1,)) | |||||
| cross_entropy = self.cross_entropy(pred_label, gt_label) | |||||
| cross_entropy = F.reshape(cross_entropy, gt_label_shape) | |||||
| # Hard example mining | |||||
| num_matched_boxes = F.reshape(num_matched_boxes, (-1,)) | |||||
| neg_masked_cross_entropy = F.cast(cross_entropy * (1 - mask), mstype.float16) | |||||
| _, loss_idx = self.sort_descend(neg_masked_cross_entropy, self.num_boxes) | |||||
| _, relative_position = self.sort(F.cast(loss_idx, mstype.float16), self.num_boxes) | |||||
| num_neg_boxes = self.minimum(num_matched_boxes * self.neg_pre_positive, self.num_boxes) | |||||
| tile_num_neg_boxes = self.tile(self.expand_dims(num_neg_boxes, -1), (1, self.num_boxes)) | |||||
| top_k_neg_mask = F.cast(self.less(relative_position, tile_num_neg_boxes), mstype.float32) | |||||
| class_loss = self.reduce_sum(cross_entropy * (mask + top_k_neg_mask), 1) | |||||
| return self.reduce_mean(class_loss / F.cast(num_matched_boxes, mstype.float32), 0) | |||||
| def __init__(self, gamma=2.0, alpha=0.75): | |||||
| super(SigmoidFocalClassificationLoss, self).__init__() | |||||
| self.sigmoid_cross_entropy = P.SigmoidCrossEntropyWithLogits() | |||||
| self.sigmoid = P.Sigmoid() | |||||
| self.pow = P.Pow() | |||||
| self.onehot = P.OneHot() | |||||
| self.on_value = Tensor(1.0, mstype.float32) | |||||
| self.off_value = Tensor(0.0, mstype.float32) | |||||
| self.gamma = gamma | |||||
| self.alpha = alpha | |||||
| def construct(self, logits, label): | |||||
| label = self.onehot(label, F.shape(logits)[-1], self.on_value, self.off_value) | |||||
| sigmoid_cross_entropy = self.sigmoid_cross_entropy(logits, label) | |||||
| sigmoid = self.sigmoid(logits) | |||||
| label = F.cast(label, mstype.float32) | |||||
| p_t = label * sigmoid + (1 - label) * (1 - sigmoid) | |||||
| modulating_factor = self.pow(1 - p_t, self.gamma) | |||||
| alpha_weight_factor = label * self.alpha + (1 - label) * (1 - self.alpha) | |||||
| focal_loss = modulating_factor * alpha_weight_factor * sigmoid_cross_entropy | |||||
| return focal_loss | |||||
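
The Cell above implements the focal loss of Lin et al. (2017), FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), with the cross-entropy factor supplied by `SigmoidCrossEntropyWithLogits`. A NumPy sketch mirroring `construct`, using the same defaults; the function name is hypothetical:

```python
# NumPy mirror of the focal loss above; `label` is one-hot, matching the
# OneHot expansion in construct.
import numpy as np

def sigmoid_focal_loss(logits, label, gamma=2.0, alpha=0.75):
    sigmoid = 1.0 / (1.0 + np.exp(-logits))
    # numerically stable element-wise sigmoid cross-entropy with logits
    ce = np.maximum(logits, 0) - logits * label + np.log1p(np.exp(-np.abs(logits)))
    p_t = label * sigmoid + (1 - label) * (1 - sigmoid)
    modulating_factor = (1 - p_t) ** gamma        # down-weights easy examples
    alpha_weight = label * alpha + (1 - label) * (1 - alpha)
    return modulating_factor * alpha_weight * ce  # element-wise; reduced by caller
```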
| class SSDWithLossCell(nn.Cell): | class SSDWithLossCell(nn.Cell): | ||||
| @@ -362,14 +340,29 @@ class SSDWithLossCell(nn.Cell): | |||||
| def __init__(self, network, config): | def __init__(self, network, config): | ||||
| super(SSDWithLossCell, self).__init__() | super(SSDWithLossCell, self).__init__() | ||||
| self.network = network | self.network = network | ||||
| self.class_loss = ClassificationLoss(config) | |||||
| self.box_loss = LocalizationLoss() | |||||
| self.less = P.Less() | |||||
| self.tile = P.Tile() | |||||
| self.reduce_sum = P.ReduceSum() | |||||
| self.reduce_mean = P.ReduceMean() | |||||
| self.expand_dims = P.ExpandDims() | |||||
| self.class_loss = SigmoidFocalClassificationLoss() | |||||
| self.loc_loss = nn.SmoothL1Loss() | |||||
| def construct(self, x, gt_loc, gt_label, num_matched_boxes): | def construct(self, x, gt_loc, gt_label, num_matched_boxes): | ||||
| pred_loc, pred_label = self.network(x) | pred_loc, pred_label = self.network(x) | ||||
| loss_cls = self.class_loss(pred_label, gt_label, num_matched_boxes) | |||||
| loss_loc = self.box_loss(pred_loc, gt_loc, gt_label, num_matched_boxes) | |||||
| return loss_cls + loss_loc | |||||
| mask = F.cast(self.less(0, gt_label), mstype.float32) | |||||
| num_matched_boxes = self.reduce_sum(F.cast(num_matched_boxes, mstype.float32)) | |||||
| # Localization Loss | |||||
| mask_loc = self.tile(self.expand_dims(mask, -1), (1, 1, 4)) | |||||
| smooth_l1 = self.loc_loss(pred_loc, gt_loc) * mask_loc | |||||
| loss_loc = self.reduce_sum(self.reduce_mean(smooth_l1, -1), -1) | |||||
| # Classification Loss | |||||
| loss_cls = self.class_loss(pred_label, gt_label) | |||||
| loss_cls = self.reduce_sum(loss_cls, (1, 2)) | |||||
| return self.reduce_sum((loss_cls + loss_loc) / num_matched_boxes) | |||||
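
The rewritten loss cell folds both terms together in place of the removed `LocalizationLoss`/`ClassificationLoss` pair: SmoothL1 on the box regressions of positive anchors plus the focal loss over all anchors, each summed and divided by the total number of matched default boxes in the batch. A NumPy sketch of that composition; shapes and names are illustrative:

```python
# Sketch of the combined SSD loss; B = batch, N = default boxes, C = classes.
import numpy as np

def ssd_total_loss(smooth_l1, focal, gt_label, num_matched_boxes):
    # smooth_l1: (B, N, 4), focal: (B, N, C), gt_label: (B, N)
    mask = (gt_label > 0).astype(np.float32)        # positive anchors only
    loss_loc = (smooth_l1.mean(-1) * mask).sum(-1)  # per-sample localization loss
    loss_cls = focal.sum(axis=(1, 2))               # per-sample classification loss
    return (loss_cls + loss_loc).sum() / num_matched_boxes.sum()
```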
| class TrainingWrapper(nn.Cell): | class TrainingWrapper(nn.Cell): | ||||
| @@ -415,7 +408,6 @@ class TrainingWrapper(nn.Cell): | |||||
| return F.depend(loss, self.optimizer(grads)) | return F.depend(loss, self.optimizer(grads)) | ||||
| class SSDWithMobileNetV2(nn.Cell): | class SSDWithMobileNetV2(nn.Cell): | ||||
| """ | """ | ||||
| MobileNetV2 architecture for SSD backbone. | MobileNetV2 architecture for SSD backbone. | ||||