
fix incremental_learning bug

- Add docs and code comments
- Fix bugs: epoch always set to 1, inference results not saved, S3 upload
  failure

Signed-off-by: JoeyHwong <joeyhwong@gknow.cn>
tags/v0.3.1
JoeyHwong 4 years ago
parent commit a218658ed6
7 changed files with 220 additions and 157 deletions
  1. examples/incremental_learning/helmet_detection/training/inference.py (+15, -11)
  2. examples/incremental_learning/helmet_detection/training/interface.py (+15, -1)
  3. lib/sedna/algorithms/hard_example_mining/__init__.py (+1, -139)
  4. lib/sedna/algorithms/hard_example_mining/hard_example_mining.py (+153, -0)
  5. lib/sedna/backend/__init__.py (+1, -1)
  6. lib/sedna/backend/tensorflow/__init__.py (+3, -2)
  7. lib/sedna/core/incremental_learning/incremental_learning.py (+32, -3)

examples/incremental_learning/helmet_detection/training/inference.py (+15, -11)

@@ -24,11 +24,11 @@ from sedna.common.file_ops import FileOps
 from sedna.core.incremental_learning import IncrementalLearning
 from interface import Estimator
 
-he_saved_url = Context.get_parameters("HE_SAVED_URL")
+he_saved_url = Context.get_parameters("HE_SAVED_URL", '/tmp')
+rsl_saved_url = Context.get_parameters("RESULT_SAVED_URL", '/tmp')
 class_names = ['person', 'helmet', 'helmet_on', 'helmet_off']
 
-FileOps.clean_folder([he_saved_url], clean=False)
+FileOps.clean_folder([he_saved_url, rsl_saved_url], clean=False)
 
 
 def draw_boxes(img, labels, scores, bboxes, class_names, colors):
@@ -59,11 +59,14 @@ def draw_boxes(img, labels, scores, bboxes, class_names, colors):
         p2 = (int(bbox[2]), int(bbox[3]))
         if (p2[0] - p1[0] < 1) or (p2[1] - p1[1] < 1):
             continue
-        cv2.rectangle(img, p1[::-1], p2[::-1],
-                      colors_code[labels[i]], box_thickness)
-        cv2.putText(img, text, (p1[1], p1[0] + 20 * (label + 1)),
-                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0),
-                    text_thickness, line_type)
+        try:
+            cv2.rectangle(img, p1[::-1], p2[::-1],
+                          colors_code[labels[i]], box_thickness)
+            cv2.putText(img, text, (p1[1], p1[0] + 20 * (label + 1)),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0),
+                        text_thickness, line_type)
+        except TypeError as err:
+            warnings.warn(f"Draw box fail: {err}")
     return img
@@ -72,12 +75,13 @@ def output_deal(is_hard_example, infer_result, nframe, img_rgb):
     img_rgb = np.array(img_rgb)
     img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
     colors = 'yellow,blue,green,red'
-    if not is_hard_example:
-        return
-
     lables, scores, bbox_list_pred = infer_result
     img = draw_boxes(img_rgb, lables, scores, bbox_list_pred, class_names,
                      colors)
-    cv2.imwrite(f"{he_saved_url}/{nframe}.jpeg", img)
+    if is_hard_example:
+        cv2.imwrite(f"{he_saved_url}/{nframe}.jpeg", img)
+    cv2.imwrite(f"{rsl_saved_url}/{nframe}.jpeg", img)
 
 
 def mkdir(path):

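Net effect of this hunk: every annotated frame is now written to rsl_saved_url, and hard examples get an extra copy in he_saved_url for later retraining. Below is a minimal standalone sketch of that saving policy, using plain environment variables in place of Sedna's Context and a hypothetical save_frame helper (not part of the example):

import os

import cv2

# Defaults mirror Context.get_parameters(..., '/tmp') in the example
he_saved_url = os.getenv("HE_SAVED_URL", "/tmp")
rsl_saved_url = os.getenv("RESULT_SAVED_URL", "/tmp")


def save_frame(img, nframe, is_hard_example):
    """Always keep the inference result; duplicate hard examples."""
    os.makedirs(rsl_saved_url, exist_ok=True)
    cv2.imwrite(f"{rsl_saved_url}/{nframe}.jpeg", img)
    if is_hard_example:
        os.makedirs(he_saved_url, exist_ok=True)
        cv2.imwrite(f"{he_saved_url}/{nframe}.jpeg", img)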

examples/incremental_learning/helmet_detection/training/interface.py (+15, -1)

@@ -15,6 +15,7 @@
 import os
 import six
 import logging
+from urllib.parse import urlparse
 
 import cv2
 import numpy as np
@@ -26,8 +27,21 @@ from validate_utils import validate
 from yolo3_multiscale import Yolo3
 from yolo3_multiscale import YoloConfig
 
 os.environ['BACKEND_TYPE'] = 'TENSORFLOW'
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+s3_url = os.getenv("S3_ENDPOINT_URL", "http://s3.amazonaws.com")
+if not (s3_url.startswith("http://") or s3_url.startswith("https://")):
+    s3_url = f"https://{s3_url}"
+s3_url = urlparse(s3_url)
+s3_use_ssl = s3_url.scheme == 'https' if s3_url.scheme else True
+
+os.environ["AWS_ACCESS_KEY_ID"] = os.getenv("ACCESS_KEY_ID")
+os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("SECRET_ACCESS_KEY")
+os.environ["S3_ENDPOINT"] = s3_url.netloc
+os.environ["S3_USE_HTTPS"] = "1" if s3_use_ssl else "0"
 LOG = logging.getLogger(__name__)
+flags = tf.flags.FLAGS
 
 
 def preprocess(image, input_shape):
@@ -89,7 +103,7 @@ class Estimator:
 
         data_gen = DataGen(yolo_config, train_data.x)
 
-        max_epochs = int(kwargs.get("max_epochs", "1"))
+        max_epochs = int(kwargs.get("epochs", flags.max_epochs))
         config = tf.ConfigProto(allow_soft_placement=True)
         config.gpu_options.allow_growth = True

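Two things happen above: the training epoch count now comes from the epochs kwarg (falling back to flags.max_epochs) instead of the hard-coded "1", and the S3 endpoint is normalized before the TensorFlow S3 environment variables are exported. The endpoint handling, pulled out as a small self-contained sketch (normalize_s3_endpoint is a hypothetical helper name, not part of the example):

import os
from urllib.parse import urlparse


def normalize_s3_endpoint(raw):
    """Return (netloc, use_https) for an endpoint that may lack a scheme."""
    if not (raw.startswith("http://") or raw.startswith("https://")):
        raw = f"https://{raw}"  # assume HTTPS when no scheme is given
    parsed = urlparse(raw)
    return parsed.netloc, parsed.scheme == "https"


endpoint, use_https = normalize_s3_endpoint(
    os.getenv("S3_ENDPOINT_URL", "http://s3.amazonaws.com"))
os.environ["S3_ENDPOINT"] = endpoint                    # e.g. "s3.amazonaws.com"
os.environ["S3_USE_HTTPS"] = "1" if use_https else "0"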



lib/sedna/algorithms/hard_example_mining/__init__.py (+1, -139)

@@ -12,142 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Hard Example Mining Algorithms"""
-import abc
-import math
-
-from sedna.common.class_factory import ClassFactory, ClassType
-
-__all__ = ('ThresholdFilter', 'CrossEntropyFilter', 'IBTFilter')
-
-
-class BaseFilter(metaclass=abc.ABCMeta):
-    """The base class to define unified interface."""
-
-    def __call__(self, infer_result=None):
-        """predict function, and it must be implemented by
-        different methods class.
-
-        :param infer_result: prediction result
-        :return: `True` means hard sample, `False` means not a hard sample.
-        """
-        raise NotImplementedError
-
-    @classmethod
-    def data_check(cls, data):
-        """Check the data in [0,1]."""
-        return 0 <= float(data) <= 1
-
-
-@ClassFactory.register(ClassType.HEM, alias="Threshold")
-class ThresholdFilter(BaseFilter, abc.ABC):
-    def __init__(self, threshold=0.5, **kwargs):
-        self.threshold = float(threshold)
-
-    def __call__(self, infer_result=None):
-        """
-        :param infer_result: [N, 6], (x0, y0, x1, y1, score, class)
-        :return: `True` means hard sample, `False` means not a hard sample.
-        """
-        # if invalid input, return False
-        if not (infer_result
-                and all(map(lambda x: len(x) > 4, infer_result))):
-            return False
-
-        image_score = 0
-
-        for bbox in infer_result:
-            image_score += bbox[4]
-
-        average_score = image_score / (len(infer_result) or 1)
-        return average_score < self.threshold
-
-
-@ClassFactory.register(ClassType.HEM, alias="CrossEntropy")
-class CrossEntropyFilter(BaseFilter, abc.ABC):
-    """ Implement the hard samples discovery methods named IBT
-    (image-box-thresholds).
-
-    :param threshold_cross_entropy: threshold_cross_entropy to filter img,
-        whose hard coefficient is less than
-        threshold_cross_entropy. And its default value is
-        threshold_cross_entropy=0.5
-    """
-
-    def __init__(self, threshold_cross_entropy=0.5, **kwargs):
-        self.threshold_cross_entropy = float(threshold_cross_entropy)
-
-    def __call__(self, infer_result=None):
-        """judge the img is hard sample or not.
-
-        :param infer_result:
-            prediction classes list,
-            such as [class1-score, class2-score, class2-score,....],
-            where class-score is the score corresponding to the class,
-            class-score value is in [0,1], who will be ignored if its value
-            not in [0,1].
-        :return: `True` means a hard sample, `False` means not a hard sample.
-        """
-
-        if not infer_result:
-            # if invalid input, return False
-            return False
-
-        log_sum = 0.0
-        data_check_list = [class_probability for class_probability
-                           in infer_result
-                           if self.data_check(class_probability)]
-
-        if len(data_check_list) != len(infer_result):
-            return False
-
-        for class_data in data_check_list:
-            log_sum += class_data * math.log(class_data)
-        confidence_score = 1 + 1.0 * log_sum / math.log(
-            len(infer_result))
-        return confidence_score < self.threshold_cross_entropy
-
-
-@ClassFactory.register(ClassType.HEM, alias="IBT")
-class IBTFilter(BaseFilter, abc.ABC):
-    """Implement the hard samples discovery methods named IBT
-    (image-box-thresholds).
-
-    :param threshold_img: threshold_img to filter img, whose hard coefficient
-        is less than threshold_img.
-    :param threshold_box: threshold_box to calculate hard coefficient, formula
-        is hard coefficient = number(prediction_boxes less than
-        threshold_box)/number(prediction_boxes)
-    """
-
-    def __init__(self, threshold_img=0.5, threshold_box=0.5, **kwargs):
-        self.threshold_box = float(threshold_box)
-        self.threshold_img = float(threshold_img)
-
-    def __call__(self, infer_result=None):
-        """Judge the img is hard sample or not.
-
-        :param infer_result:
-            prediction boxes list,
-            such as [bbox1, bbox2, bbox3,....],
-            where bbox = [xmin, ymin, xmax, ymax, score, label]
-            score should be in [0,1], who will be ignored if its value not
-            in [0,1].
-        :return: `True` means a hard sample, `False` means not a hard sample.
-        """
-
-        if not (infer_result
-                and all(map(lambda x: len(x) > 4, infer_result))):
-            # if invalid input, return False
-            return False
-
-        data_check_list = [bbox[4] for bbox in infer_result
-                           if self.data_check(bbox[4])]
-        if len(data_check_list) != len(infer_result):
-            return False
-
-        confidence_score_list = [
-            float(box_score) for box_score in data_check_list
-            if float(box_score) <= self.threshold_box]
-        return (len(confidence_score_list) / len(infer_result)
-                >= (1 - self.threshold_img))
+from .hard_example_mining import *

lib/sedna/algorithms/hard_example_mining/hard_example_mining.py (+153, -0)

@@ -0,0 +1,153 @@
# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Hard Example Mining Algorithms"""
import abc
import math

from sedna.common.class_factory import ClassFactory, ClassType

__all__ = ('ThresholdFilter', 'CrossEntropyFilter', 'IBTFilter')


class BaseFilter(metaclass=abc.ABCMeta):
    """The base class to define unified interface."""

    def __call__(self, infer_result=None):
        """predict function, and it must be implemented by
        different methods class.

        :param infer_result: prediction result
        :return: `True` means hard sample, `False` means not a hard sample.
        """
        raise NotImplementedError

    @classmethod
    def data_check(cls, data):
        """Check the data in [0,1]."""
        return 0 <= float(data) <= 1


@ClassFactory.register(ClassType.HEM, alias="Threshold")
class ThresholdFilter(BaseFilter, abc.ABC):
    def __init__(self, threshold=0.5, **kwargs):
        self.threshold = float(threshold)

    def __call__(self, infer_result=None):
        """
        :param infer_result: [N, 6], (x0, y0, x1, y1, score, class)
        :return: `True` means hard sample, `False` means not a hard sample.
        """
        # if invalid input, return False
        if not (infer_result
                and all(map(lambda x: len(x) > 4, infer_result))):
            return False

        image_score = 0

        for bbox in infer_result:
            image_score += bbox[4]

        average_score = image_score / (len(infer_result) or 1)
        return average_score < self.threshold


@ClassFactory.register(ClassType.HEM, alias="CrossEntropy")
class CrossEntropyFilter(BaseFilter, abc.ABC):
    """ Implement the hard samples discovery methods named IBT
    (image-box-thresholds).

    :param threshold_cross_entropy: threshold_cross_entropy to filter img,
        whose hard coefficient is less than
        threshold_cross_entropy. And its default value is
        threshold_cross_entropy=0.5
    """

    def __init__(self, threshold_cross_entropy=0.5, **kwargs):
        self.threshold_cross_entropy = float(threshold_cross_entropy)

    def __call__(self, infer_result=None):
        """judge the img is hard sample or not.

        :param infer_result:
            prediction classes list,
            such as [class1-score, class2-score, class2-score,....],
            where class-score is the score corresponding to the class,
            class-score value is in [0,1], who will be ignored if its value
            not in [0,1].
        :return: `True` means a hard sample, `False` means not a hard sample.
        """

        if not infer_result:
            # if invalid input, return False
            return False

        log_sum = 0.0
        data_check_list = [class_probability for class_probability
                           in infer_result
                           if self.data_check(class_probability)]

        if len(data_check_list) != len(infer_result):
            return False

        for class_data in data_check_list:
            log_sum += class_data * math.log(class_data)
        confidence_score = 1 + 1.0 * log_sum / math.log(
            len(infer_result))
        return confidence_score < self.threshold_cross_entropy


@ClassFactory.register(ClassType.HEM, alias="IBT")
class IBTFilter(BaseFilter, abc.ABC):
    """Implement the hard samples discovery methods named IBT
    (image-box-thresholds).

    :param threshold_img: threshold_img to filter img, whose hard coefficient
        is less than threshold_img.
    :param threshold_box: threshold_box to calculate hard coefficient, formula
        is hard coefficient = number(prediction_boxes less than
        threshold_box)/number(prediction_boxes)
    """

    def __init__(self, threshold_img=0.5, threshold_box=0.5, **kwargs):
        self.threshold_box = float(threshold_box)
        self.threshold_img = float(threshold_img)

    def __call__(self, infer_result=None):
        """Judge the img is hard sample or not.

        :param infer_result:
            prediction boxes list,
            such as [bbox1, bbox2, bbox3,....],
            where bbox = [xmin, ymin, xmax, ymax, score, label]
            score should be in [0,1], who will be ignored if its value not
            in [0,1].
        :return: `True` means a hard sample, `False` means not a hard sample.
        """

        if not (infer_result
                and all(map(lambda x: len(x) > 4, infer_result))):
            # if invalid input, return False
            return False

        data_check_list = [bbox[4] for bbox in infer_result
                           if self.data_check(bbox[4])]
        if len(data_check_list) != len(infer_result):
            return False

        confidence_score_list = [
            float(box_score) for box_score in data_check_list
            if float(box_score) <= self.threshold_box]
        return (len(confidence_score_list) / len(infer_result)
                >= (1 - self.threshold_img))

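For orientation, the three filters are called with the raw output of one inference and return True when the sample should be treated as hard. A short usage sketch, assuming the sedna lib is on the import path (after this commit the classes are re-exported from the package __init__):

from sedna.algorithms.hard_example_mining import CrossEntropyFilter, IBTFilter

# Detection output: one row per box, [xmin, ymin, xmax, ymax, score, label]
boxes = [
    [10, 20, 110, 220, 0.91, 1],
    [50, 60, 150, 260, 0.32, 0],
]
ibt = IBTFilter(threshold_img=0.5, threshold_box=0.5)
print(ibt(boxes))    # True: half of the boxes score below threshold_box

# Classification output: per-class probabilities for one sample
scores = [0.34, 0.33, 0.33]
ce = CrossEntropyFilter(threshold_cross_entropy=0.5)
print(ce(scores))    # True: near-uniform scores give a low confidence score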
lib/sedna/backend/__init__.py (+1, -1)

@@ -48,7 +48,7 @@ def set_backend(estimator=None, config=None):
         warnings.warn(f"{backend_type} Not Support yet, use itself")
         from sedna.backend.base import BackendBase as REGISTER
     model_save_url = config.get("model_url")
-    base_model_save = config.get("base_model_save") or model_save_url
+    base_model_save = config.get("base_model_url") or model_save_url
     model_save_name = config.get("model_name")
     return REGISTER(
         estimator=estimator, use_cuda=use_cuda,


lib/sedna/backend/tensorflow/__init__.py (+3, -2)

@@ -36,8 +36,9 @@ class TFBackend(BackendBase):
         super(TFBackend, self).__init__(
             estimator=estimator, fine_tune=fine_tune, **kwargs)
         self.framework = "tensorflow"
-        sess_config = self._init_gpu_session_config(
-        ) if self.use_cuda else self._init_cpu_session_config()
+
+        sess_config = (self._init_gpu_session_config()
+                       if self.use_cuda else self._init_cpu_session_config())
         self.graph = tf.Graph()
 
         with self.graph.as_default():


lib/sedna/core/incremental_learning/incremental_learning.py (+32, -3)

@@ -28,9 +28,13 @@ class IncrementalLearning(JobBase):
     Incremental learning
     """
 
-    def __init__(self, estimator, config=None):
-        super(IncrementalLearning, self).__init__(
-            estimator=estimator, config=config)
+    def __init__(self, estimator):
+        """
+        Initialize an IncrementalLearning job
+        :param estimator: customized estimator
+        """
+
+        super(IncrementalLearning, self).__init__(estimator=estimator)
 
         self.model_urls = self.get_parameters(
             "MODEL_URLS")  # use in evaluation
@@ -42,6 +46,15 @@
               valid_data=None,
               post_process=None,
               **kwargs):
+        """
+        Training task for IncrementalLearning
+        :param train_data: data source used for training
+        :param valid_data: data source used for evaluation
+        :param post_process: post process
+        :param kwargs: params for training of the customized estimator
+        :return: estimator
+        """
+
         callback_func = None
         if post_process is not None:
             callback_func = ClassFactory.get_cls(
@@ -58,6 +71,14 @@
             self.estimator) if callback_func else self.estimator
 
     def inference(self, data=None, post_process=None, **kwargs):
+        """
+        Inference task for IncrementalLearning
+        :param data: inference sample
+        :param post_process: post process
+        :param kwargs: params for inference of the customized estimator
+        :return: inference result, result after post_process,
+                 and whether the sample is a hard example
+        """
+
         if not self.estimator.has_load:
             self.estimator.load(self.model_path)
@@ -81,6 +102,14 @@
         return infer_res, res, is_hard_example
 
     def evaluate(self, data, post_process=None, **kwargs):
+        """
+        Evaluation task for IncrementalLearning
+        :param data: data source used for evaluation
+        :param post_process: post process
+        :param kwargs: params for evaluation of the customized estimator
+        :return: evaluation metrics
+        """
+
         callback_func = None
         if callable(post_process):
             callback_func = post_process

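Putting the new docstrings together, the job exposes three entry points. A hedged sketch of how the helmet-detection example drives them (the concrete Estimator and the data sources come from the example code, and the epochs kwarg is the one the interface.py fix above now reads; argument passing follows the docstrings, not a verified end-to-end run):

from sedna.core.incremental_learning import IncrementalLearning

from interface import Estimator  # the example's customized TF estimator

job = IncrementalLearning(estimator=Estimator)

# Training: extra kwargs are forwarded to Estimator.train(), so `epochs`
# replaces the old hard-coded "max_epochs" default of 1.
# job.train(train_data=train_data, valid_data=valid_data, epochs=10)

# Inference: returns the raw result, the post-processed result and the
# hard-example flag, in that order.
# infer_res, res, is_hard_example = job.inference(frame)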
