Add lib code for joint inference

Signed-off-by: khalid-davis <huangqinkai1@huawei.com>
khalid-davis, llhuii · 4 years ago · commit e5d59903cf · tags/v0.1.0
20 changed files with 726 additions and 0 deletions
1. lib/__init__.py (+0 -0)
2. lib/neptune/__init__.py (+16 -0)
3. lib/neptune/common/__init__.py (+0 -0)
4. lib/neptune/common/config.py (+23 -0)
5. lib/neptune/common/constant.py (+23 -0)
6. lib/neptune/context.py (+53 -0)
7. lib/neptune/hard_example_mining/__init__.py (+3 -0)
8. lib/neptune/hard_example_mining/base.py (+31 -0)
9. lib/neptune/hard_example_mining/hard_example_helpers/__init__.py (+1 -0)
10. lib/neptune/hard_example_mining/hard_example_helpers/data_check_utils.py (+8 -0)
11. lib/neptune/hard_example_mining/image_classification/__init__.py (+0 -0)
12. lib/neptune/hard_example_mining/image_classification/hard_mine_filters.py (+53 -0)
13. lib/neptune/hard_example_mining/object_detection/__init__.py (+0 -0)
14. lib/neptune/hard_example_mining/object_detection/scores_filters.py (+56 -0)
15. lib/neptune/joint_inference/__init__.py (+1 -0)
16. lib/neptune/joint_inference/data.py (+20 -0)
17. lib/neptune/joint_inference/joint_inference.py (+370 -0)
18. lib/neptune/lc_client.py (+43 -0)
19. lib/requirements.txt (+6 -0)
20. lib/setup.py (+19 -0)

lib/__init__.py (+0 -0)


lib/neptune/__init__.py (+16 -0)

@@ -0,0 +1,16 @@
import logging

from . import joint_inference
from .context import context


def log_configure():
    logging.basicConfig(
        format='[%(asctime)s][%(name)s][%(levelname)s][%(lineno)s]: '
               '%(message)s',
        level=logging.INFO)


LOG = logging.getLogger(__name__)

log_configure()

lib/neptune/common/__init__.py (+0 -0)


lib/neptune/common/config.py (+23 -0)

@@ -0,0 +1,23 @@
import os


class BaseConfig:
    """The base config; values are read from environment variables and
    should not be changed at runtime."""
    # dataset
    train_dataset_url = os.getenv("TRAIN_DATASET_URL")
    test_dataset_url = os.getenv("TEST_DATASET_URL")
    # k8s crd info
    namespace = os.getenv("NAMESPACE", "")
    worker_name = os.getenv("WORKER_NAME", "")
    service_name = os.getenv("SERVICE_NAME", "")

    model_url = os.getenv("MODEL_URL")

    # user parameter
    parameters = os.getenv("PARAMETERS")
    # hard example mining algorithm
    hem_name = os.getenv("HEM_NAME")
    hem_parameters = os.getenv("HEM_PARAMETERS")

    def __init__(self):
        pass

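Because these are class attributes, they are evaluated once at import time, so the environment variables must be exported before `neptune` is imported. A minimal sketch, with hypothetical values:

    import os

    # Hypothetical values; must be set before importing neptune.
    os.environ["MODEL_URL"] = "/models/little_model.pb"
    os.environ["WORKER_NAME"] = "edge-worker-0"

    from neptune.common.config import BaseConfig

    assert BaseConfig.model_url == "/models/little_model.pb"
    assert BaseConfig.worker_name == "edge-worker-0"
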
lib/neptune/common/constant.py (+23 -0)

@@ -0,0 +1,23 @@
import logging
from enum import Enum

LOG = logging.getLogger(__name__)


class Framework(Enum):
    Tensorflow = "tensorflow"
    Pytorch = "pytorch"
    Mindspore = "mindspore"


class K8sResourceKind(Enum):
    JOINT_INFERENCE_SERVICE = "jointinferenceservice"


class K8sResourceKindStatus(Enum):
    COMPLETED = "completed"
    FAILED = "failed"
    RUNNING = "running"


FRAMEWORK = Framework.Tensorflow  # TODO: should read from env.

lib/neptune/context.py (+53 -0)

@@ -0,0 +1,53 @@
import json
import logging

from neptune.common.config import BaseConfig

LOG = logging.getLogger(__name__)


def parse_parameters(parameters):
    """
    :param parameters:
        [{"key":"batch_size","value":"32"},
         {"key":"learning_rate","value":"0.001"},
         {"key":"min_node_number","value":"3"}]
        ---->
    :return:
        {'batch_size': '32', 'learning_rate': '0.001', 'min_node_number': '3'}
    """
    p = {}
    if parameters is None or len(parameters) == 0:
        LOG.info(f"PARAMETERS={parameters}, return empty dict.")
        return p
    j = json.loads(parameters)
    for d in j:
        p[d.get('key')] = d.get('value')
    return p


class Context:
    """The Context provides the capability of obtaining the values of the
    `PARAMETERS` and `HEM_PARAMETERS` fields."""

    def __init__(self):
        self.parameters = parse_parameters(BaseConfig.parameters)
        self.hem_parameters = parse_parameters(BaseConfig.hem_parameters)

    def get_context(self):
        return self.parameters

    def get_parameters(self, param, default=None):
        """Get the value of the key `param` in `PARAMETERS`;
        if it does not exist, the default value is returned."""
        value = self.parameters.get(param)
        return value if value else default

    def get_hem_parameters(self, param, default=None):
        """Get the value of the key `param` in `HEM_PARAMETERS`;
        if it does not exist, the default value is returned."""
        value = self.hem_parameters.get(param)
        return value if value else default


context = Context()
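
A quick sanity check of `parse_parameters`, mirroring the JSON in the docstring. Note the parsed values stay strings, so callers cast them as needed:

    from neptune.context import parse_parameters

    params = parse_parameters(
        '[{"key":"batch_size","value":"32"},'
        ' {"key":"learning_rate","value":"0.001"}]')
    assert params == {'batch_size': '32', 'learning_rate': '0.001'}
    batch_size = int(params['batch_size'])  # cast as needed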

lib/neptune/hard_example_mining/__init__.py (+3 -0)

@@ -0,0 +1,3 @@
from .base import BaseFilter, ThresholdFilter
from .image_classification.hard_mine_filters import CrossEntropyFilter
from .object_detection.scores_filters import IBTFilter

lib/neptune/hard_example_mining/base.py (+31 -0)

@@ -0,0 +1,31 @@
class BaseFilter:
    """The base class to define a unified interface."""

    def hard_judge(self, infer_result=None):
        """Judge whether a sample is a hard sample; must be implemented
        by each concrete filter class.

        :param infer_result: prediction result
        :return: `True` means a hard sample, `False` means not a hard sample.
        """
        raise NotImplementedError


class ThresholdFilter(BaseFilter):
    def __init__(self, threshold=0.5):
        self.threshold = threshold

    def hard_judge(self, infer_result=None):
        """
        :param infer_result: [N, 6], (x0, y0, x1, y1, score, class)
        :return: `True` means a hard sample, `False` means not a hard sample.
        """
        if not infer_result:
            return True

        image_score = 0
        for bbox in infer_result:
            image_score += bbox[4]
        average_score = image_score / (len(infer_result) or 1)
        return average_score < self.threshold

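A short illustration of `ThresholdFilter` with hypothetical boxes; the mean of the box scores is compared against the threshold:

    from neptune.hard_example_mining import ThresholdFilter

    hem = ThresholdFilter(threshold=0.5)
    # Two boxes in (x0, y0, x1, y1, score, class) form; mean score is 0.4.
    boxes = [(10, 10, 50, 50, 0.6, 1), (20, 20, 80, 80, 0.2, 3)]
    assert hem.hard_judge(boxes) is True   # 0.4 < 0.5 -> hard sample
    assert hem.hard_judge([]) is True      # an empty result counts as hard
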
lib/neptune/hard_example_mining/hard_example_helpers/__init__.py (+1 -0)

@@ -0,0 +1 @@
from .data_check_utils import data_check

lib/neptune/hard_example_mining/hard_example_helpers/data_check_utils.py (+8 -0)

@@ -0,0 +1,8 @@
import logging

logger = logging.getLogger(__name__)


def data_check(data):
    """Check that the data is in [0, 1]."""
    return 0 <= float(data) <= 1

lib/neptune/hard_example_mining/image_classification/__init__.py (+0 -0)


lib/neptune/hard_example_mining/image_classification/hard_mine_filters.py (+53 -0)

@@ -0,0 +1,53 @@
import logging
import math

from neptune.hard_example_mining import BaseFilter
from neptune.hard_example_mining.hard_example_helpers import data_check

logger = logging.getLogger(__name__)


class CrossEntropyFilter(BaseFilter):
    """Implement the hard sample discovery method based on cross entropy.

    :param threshold_cross_entropy: threshold compared against the
        confidence score computed from the class scores; its default
        value is 0.5.
    """

    def __init__(self, threshold_cross_entropy=0.5):
        self.threshold_cross_entropy = threshold_cross_entropy

    def hard_judge(self, infer_result=None):
        """Judge whether the image is a hard sample.

        :param infer_result:
            prediction class scores list, such as
            [class1-score, class2-score, class3-score, ...],
            where each class-score is the score of the corresponding class;
            a class-score must be in [0,1], and the whole result is ignored
            if any value is not in [0,1].
        :return: `True` means a hard sample, `False` means not a hard sample.
        """
        if infer_result is None:
            logger.warning(f'infer result is invalid, value: {infer_result}!')
            return False
        elif len(infer_result) == 0:
            return False
        else:
            log_sum = 0.0
            data_check_list = [class_probability for class_probability
                               in infer_result
                               if data_check(class_probability)]
            if len(data_check_list) == len(infer_result):
                for class_data in data_check_list:
                    log_sum += class_data * math.log(class_data)
                confidence_score = 1 + 1.0 * log_sum / math.log(
                    len(infer_result))
                return confidence_score >= self.threshold_cross_entropy
            else:
                logger.warning("every value of infer_result should be in "
                               f"[0,1], your data is {infer_result}")
                return False

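A worked example of the confidence score, assuming strictly positive class scores (math.log(0) would raise): confidence = 1 + sum(p * ln p) / ln(n), i.e. one minus the normalized entropy of the prediction.

    from neptune.hard_example_mining import CrossEntropyFilter

    hem = CrossEntropyFilter(threshold_cross_entropy=0.5)
    # [0.9, 0.05, 0.05]: log_sum ~= -0.394, confidence ~= 1 - 0.359 = 0.64
    assert hem.hard_judge([0.9, 0.05, 0.05]) is True
    # A uniform result has maximal entropy, so confidence ~= 0.0
    assert hem.hard_judge([1 / 3, 1 / 3, 1 / 3]) is False
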
lib/neptune/hard_example_mining/object_detection/__init__.py (+0 -0)


lib/neptune/hard_example_mining/object_detection/scores_filters.py (+56 -0)

@@ -0,0 +1,56 @@
import logging

from neptune.hard_example_mining import BaseFilter
from neptune.hard_example_mining.hard_example_helpers import data_check

logger = logging.getLogger(__name__)


class IBTFilter(BaseFilter):
    """Implement the hard sample discovery method named IBT
    (image-box-thresholds).

    :param threshold_img: filter out images whose hard coefficient is less
        than threshold_img.
    :param threshold_box: threshold used to compute the hard coefficient:
        hard coefficient = number(prediction boxes whose score is less than
        threshold_box) / number(prediction boxes)
    """

    def __init__(self, threshold_img=0.5, threshold_box=0.5):
        self.threshold_box = threshold_box
        self.threshold_img = threshold_img

    def hard_judge(self, infer_result=None):
        """Judge whether the image is a hard sample.

        :param infer_result:
            prediction boxes list, such as [bbox1, bbox2, bbox3, ...],
            where bbox = [xmin, ymin, xmax, ymax, score, label];
            score must be in [0,1], and the whole result is ignored if any
            value is not in [0,1].
        :return: `True` means a hard sample, `False` means not a hard sample.
        """
        if infer_result is None:
            logger.warning(f'infer result is invalid, value: {infer_result}!')
            return False
        elif len(infer_result) == 0:
            return False
        else:
            data_check_list = [bbox[4] for bbox in infer_result
                               if data_check(bbox[4])]
            if len(data_check_list) == len(infer_result):
                confidence_score_list = [
                    float(box_score) for box_score in data_check_list
                    if float(box_score) <= self.threshold_box]
                if (len(confidence_score_list) / len(infer_result)) \
                        >= (1 - self.threshold_img):
                    return True
                else:
                    return False
            else:
                logger.warning(
                    "every value of infer_result should be in [0,1], "
                    f"your data is {infer_result}")
                return False

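A worked example of the IBT rule with hypothetical boxes: 2 of the 3 scores fall at or below threshold_box, so the hard coefficient 2/3 >= 1 - threshold_img and the image is marked hard.

    from neptune.hard_example_mining import IBTFilter

    hem = IBTFilter(threshold_img=0.5, threshold_box=0.5)
    # Boxes in [xmin, ymin, xmax, ymax, score, label] form.
    boxes = [[10, 10, 50, 50, 0.9, 0],
             [20, 20, 80, 80, 0.3, 1],
             [30, 30, 60, 60, 0.2, 1]]
    assert hem.hard_judge(boxes) is True
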
lib/neptune/joint_inference/__init__.py (+1 -0)

@@ -0,0 +1 @@
from .joint_inference import *

lib/neptune/joint_inference/data.py (+20 -0)

@@ -0,0 +1,20 @@
import json


class ServiceInfo:
    def __init__(self):
        self.startTime = ''
        self.updateTime = ''
        self.inferenceNumber = 0
        self.hardExampleNumber = 0
        self.uploadCloudRatio = 0

    @staticmethod
    def from_json(json_str):
        info = ServiceInfo()
        info.__dict__ = json.loads(json_str)
        return info

    def to_json(self):
        info = json.dumps(self.__dict__)
        return info

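A round-trip sanity check of the JSON helpers (field values hypothetical):

    from neptune.joint_inference.data import ServiceInfo

    info = ServiceInfo()
    info.inferenceNumber = 10
    info.hardExampleNumber = 2
    restored = ServiceInfo.from_json(info.to_json())
    assert restored.inferenceNumber == 10
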
lib/neptune/joint_inference/joint_inference.py (+370 -0)

@@ -0,0 +1,370 @@
import abc
import json
import logging
import os
import threading
import time

import cv2
import numpy as np
import requests
import tensorflow as tf
from PIL import Image
from flask import Flask, request

import neptune
from neptune.common.config import BaseConfig
from neptune.common.constant import K8sResourceKind
from neptune.hard_example_mining import CrossEntropyFilter, IBTFilter, \
ThresholdFilter
from neptune.joint_inference.data import ServiceInfo
from neptune.lc_client import LCClient

LOG = logging.getLogger(__name__)


class BigModelConfig(BaseConfig):
    def __init__(self):
        BaseConfig.__init__(self)
        self.bind_ip = os.getenv("BIG_MODEL_BIND_IP", "0.0.0.0")
        self.bind_port = (
            int(os.getenv("BIG_MODEL_BIND_PORT", "5000"))
        )


class LittleModelConfig(BaseConfig):
    def __init__(self):
        BaseConfig.__init__(self)


class BigModelClientConfig:
    def __init__(self):
        self.ip = os.getenv("BIG_MODEL_IP")
        self.port = int(os.getenv("BIG_MODEL_PORT", "5000"))


class BaseModel:
    """Model abstract class.

    :param preprocess: function called before inference
    :param postprocess: function called after inference
    :param input_shape: input shape
    :param create_input_feed: the function of creating input feed
    :param create_output_fetch: the function of creating output fetch
    """

    def __init__(self, preprocess=None, postprocess=None, input_shape=(0, 0),
                 create_input_feed=None, create_output_fetch=None):
        self.preprocess = preprocess
        self.postprocess = postprocess
        self.input_shape = input_shape
        if create_input_feed is None or create_output_fetch is None:
            raise RuntimeError("Please offer create_input_feed "
                               "and create_output_fetch function")
        self.create_input_feed = create_input_feed
        self.create_output_fetch = create_output_fetch

    @abc.abstractmethod
    def _load_model(self):
        pass

    @abc.abstractmethod
    def inference(self, img_data):
        pass


class BigModelClient:
    """Remote big model service, which interacts with the cloud big model."""
    _retry = 5
    _retry_interval_seconds = 1

    def __init__(self):
        self.config = BigModelClientConfig()
        self.big_model_endpoint = "http://{0}:{1}".format(
            self.config.ip,
            self.config.port
        )

    def _load_model(self):
        pass

    def inference(self, img_data):
        """Use the remote big model server for inference."""
        _, encoded_image = cv2.imencode(".jpeg", img_data)
        files = {"images": encoded_image}
        error = None
        for i in range(BigModelClient._retry):
            try:
                res = requests.post(self.big_model_endpoint, timeout=5,
                                    files=files)
                if res.status_code < 300:
                    return res.json().get("data")
                else:
                    LOG.error(f"send request to {self.big_model_endpoint} "
                              f"failed, status is {res.status_code}")
                    return None
            except requests.exceptions.RequestException as e:
                error = e
                time.sleep(BigModelClient._retry_interval_seconds)

        LOG.error(f"send request to {self.big_model_endpoint} failed, "
                  f"error is {error}, retry times: {BigModelClient._retry}")
        return None


class TSBigModelService(BaseModel):
    """Large model service implemented based on TensorFlow.
    Provides RESTful interfaces for large-model inference.
    """

    def __init__(self, preprocess=None, postprocess=None, input_shape=(0, 0),
                 create_input_feed=None, create_output_fetch=None):
        BaseModel.__init__(self, preprocess, postprocess, input_shape,
                           create_input_feed, create_output_fetch)
        self.config = BigModelConfig()

        self.input_shape = input_shape
        self._load_model()

        self.app = Flask(__name__)
        self.register()
        self.app.run(host=self.config.bind_ip,
                     port=self.config.bind_port)

    def register(self):
        @self.app.route('/', methods=['POST'])
        def inference():
            f = request.files.get('images')
            image = Image.open(f)
            image = image.convert("RGB")
            img_data, org_img_shape = self.preprocess(image, self.input_shape)
            data = self.inference(img_data)
            result = self.postprocess(data, org_img_shape)
            # encapsulate the user result
            data = {"data": result}
            return json.dumps(data)

    def _load_model(self):
        self.graph = tf.Graph()
        self.sess = tf.compat.v1.InteractiveSession(graph=self.graph)

        with tf.io.gfile.GFile(self.config.model_url, 'rb') as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())

            tf.import_graph_def(graph_def, name='')
        LOG.info(f"Import yolo model from {self.config.model_url} end .......")

    def inference(self, img_data):
        input_feed = self.create_input_feed(self.sess, img_data)
        output_fetch = self.create_output_fetch(self.sess)

        return self.sess.run(output_fetch, input_feed)


class TSLittleModel(BaseModel):
    """Little model implemented based on TensorFlow.
    Performs local inference at the edge.
    """

    def __init__(self, preprocess=None, postprocess=None, input_shape=(0, 0),
                 create_input_feed=None, create_output_fetch=None):
        BaseModel.__init__(self, preprocess, postprocess, input_shape,
                           create_input_feed, create_output_fetch)

        self.config = LittleModelConfig()

        graph = tf.Graph()
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 0.1
        self.session = tf.Session(graph=graph, config=config)
        self._load_model()

    def _load_model(self):
        with self.session.as_default():
            with self.session.graph.as_default():
                with tf.gfile.FastGFile(self.config.model_url,
                                        'rb') as handle:
                    LOG.info(f"Load model {self.config.model_url}, "
                             f"ParseFromString start .......")
                    graph_def = tf.GraphDef()
                    graph_def.ParseFromString(handle.read())
                    LOG.info("ParseFromString end .......")

                    tf.import_graph_def(graph_def, name='')
                    LOG.info("Import_graph_def end .......")

        LOG.info("Import model from pb end .......")

    def inference(self, img_data):
        img_data_np = np.array(img_data)
        with self.session.as_default():
            new_image = self.preprocess(img_data_np, self.input_shape)
            input_feed = self.create_input_feed(self.session, new_image,
                                                img_data_np)
            output_fetch = self.create_output_fetch(self.session)
            return self.session.run(output_fetch, input_feed)


class LCReporter(threading.Thread):
    """A thread that periodically reports inference statistics to the LC
    (local controller).
    """

    def __init__(self):
        threading.Thread.__init__(self)

        # the value of statistics
        self.inference_number = 0
        self.hard_example_number = 0
        self.period_interval = int(os.getenv("LC_PERIOD", "30"))
        # The system resets the period_increment after sending the messages
        # to the LC. If the period_increment is 0 in the current period,
        # the system does not send the messages to the LC.
        self.period_increment = 0
        self.lock = threading.Lock()

    def update_for_edge_inference(self):
        self.lock.acquire()
        self.inference_number += 1
        self.period_increment += 1
        self.lock.release()

    def update_for_collaboration_inference(self):
        self.lock.acquire()
        self.inference_number += 1
        self.hard_example_number += 1
        self.period_increment += 1
        self.lock.release()

    def run(self):
        while True:
            info = ServiceInfo()
            info.startTime = time.strftime("%Y-%m-%d %H:%M:%S",
                                           time.localtime())

            time.sleep(self.period_interval)
            if self.period_increment == 0:
                LOG.debug("period increment is zero, skip report")
                continue
            info.updateTime = time.strftime("%Y-%m-%d %H:%M:%S",
                                            time.localtime())
            info.inferenceNumber = self.inference_number
            info.hardExampleNumber = self.hard_example_number
            info.uploadCloudRatio = (
                self.hard_example_number / self.inference_number
            )
            message = {
                "name": BaseConfig.worker_name,
                "namespace": BaseConfig.namespace,
                "ownerName": BaseConfig.service_name,
                "ownerKind": K8sResourceKind.JOINT_INFERENCE_SERVICE.value,
                "kind": "inference",
                "ownerInfo": info.__dict__,
                "results": []
            }
            LCClient.send(BaseConfig.worker_name, message)
            self.period_increment = 0


class InferenceResult:
    """The result class for joint inference.

    :param is_hard_sample: `True` means a hard sample, `False` means not a
        hard sample
    :param final_result: the final inference result
    :param hard_sample_edge_result: the edge little model inference result of
        a hard sample
    :param hard_sample_cloud_result: the cloud big model inference result of
        a hard sample
    """

    def __init__(self, is_hard_sample, final_result,
                 hard_sample_edge_result, hard_sample_cloud_result):
        self.is_hard_sample = is_hard_sample
        self.final_result = final_result
        self.hard_sample_edge_result = hard_sample_edge_result
        self.hard_sample_cloud_result = hard_sample_cloud_result


class JointInference:
    """Class provided for external systems for model joint inference.

    :param little_model: the little model entity for edge inference
    :param hard_example_mining_algorithm: the algorithm for judging hard
        samples
    :param pre_hook: the pre-processing function of edge inference
    :param post_hook: the post-processing function of edge inference
    """

    def __init__(self, little_model: BaseModel,
                 hard_example_mining_algorithm=None,
                 pre_hook=None, post_hook=None):
        self.little_model = little_model
        self.big_model = BigModelClient()
        # TODO: how to process a user-defined cloud_offload_algorithm,
        #  especially its parameters
        if hard_example_mining_algorithm is None:
            hem_name = BaseConfig.hem_name

            if hem_name == "IBT":
                threshold_box = float(neptune.context.get_hem_parameters(
                    "threshold_box", 0.5
                ))
                threshold_img = float(neptune.context.get_hem_parameters(
                    "threshold_img", 0.5
                ))
                hard_example_mining_algorithm = IBTFilter(threshold_img,
                                                          threshold_box)
            elif hem_name == "CrossEntropy":
                threshold_cross_entropy = float(
                    neptune.context.get_hem_parameters(
                        "threshold_cross_entropy", 0.5
                    )
                )
                hard_example_mining_algorithm = CrossEntropyFilter(
                    threshold_cross_entropy)
            else:
                hard_example_mining_algorithm = ThresholdFilter()

        self.cloud_offload_algorithm = hard_example_mining_algorithm
        self.pre_hook = pre_hook
        self.post_hook = post_hook

        self.lc_reporter = LCReporter()
        self.lc_reporter.setDaemon(True)
        self.lc_reporter.start()

    def inference(self, img_data) -> InferenceResult:
        """Image inference function."""
        img_data_pre = img_data
        if self.pre_hook:
            img_data_pre = self.pre_hook(img_data_pre)
        edge_result = self.little_model.inference(img_data_pre)
        if self.post_hook:
            edge_result = self.post_hook(edge_result)
        is_hard_sample = self.cloud_offload_algorithm.hard_judge(edge_result)
        if not is_hard_sample:
            LOG.debug("not hard sample, use edge result directly")
            self.lc_reporter.update_for_edge_inference()
            return InferenceResult(False, edge_result, None, None)

        cloud_result = self._cloud_inference(img_data)
        if cloud_result is None:
            LOG.warning("retrieve cloud infer service failed, use edge result")
            self.lc_reporter.update_for_edge_inference()
            return InferenceResult(True, edge_result, edge_result, None)
        else:
            LOG.debug(f"retrieve cloud infer service success, use cloud "
                      f"result, cloud result:{cloud_result}")
            self.lc_reporter.update_for_collaboration_inference()
            return InferenceResult(True, cloud_result, edge_result,
                                   cloud_result)

    def _cloud_inference(self, img_rgb):
        return self.big_model.inference(img_rgb)


def _get_or_default(parameter, default):
    value = neptune.context.get_parameters(parameter)
    return value if value else default

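A minimal wiring sketch for the edge side. The tensor names and helper callables below are hypothetical placeholders for a model-specific implementation, and MODEL_URL plus BIG_MODEL_IP/BIG_MODEL_PORT are assumed to be set in the environment:

    import numpy as np

    from neptune.joint_inference import JointInference, TSLittleModel

    def preprocess(img, input_shape):  # placeholder: resize/normalize here
        return img

    def create_input_feed(sess, new_image, img_data):
        input_img = sess.graph.get_tensor_by_name('input:0')  # assumed name
        return {input_img: new_image}

    def create_output_fetch(sess):
        return [sess.graph.get_tensor_by_name('output:0')]    # assumed name

    little_model = TSLittleModel(preprocess=preprocess,
                                 input_shape=(416, 416),
                                 create_input_feed=create_input_feed,
                                 create_output_fetch=create_output_fetch)
    inference_instance = JointInference(little_model)

    result = inference_instance.inference(np.zeros((416, 416, 3), np.uint8))
    if result.is_hard_sample:
        print("hard sample, offloaded; final result:", result.final_result)
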
lib/neptune/lc_client.py (+43 -0)

@@ -0,0 +1,43 @@
import logging
import os
import time

import requests

LOG = logging.getLogger(__name__)


class LCClientConfig:
    def __init__(self):
        self.lc_server = os.getenv("LC_SERVER", "http://127.0.0.1:9100")


class LCClient:
    _retry = 3
    _retry_interval_seconds = 0.5
    config = LCClientConfig()

    @classmethod
    def send(cls, worker_name, message: dict):
        url = '{0}/neptune/workers/{1}/info'.format(
            cls.config.lc_server,
            worker_name
        )
        error = None
        for i in range(cls._retry):
            try:
                res = requests.post(url=url, json=message)
                LOG.info(
                    f"send to lc, url={url}, data={message}, "
                    f"state={res.status_code}")
                return res.status_code < 300
            except Exception as e:
                error = e
                time.sleep(cls._retry_interval_seconds)

        LOG.warning(
            f"can't connect to lc[{cls.config.lc_server}] "
            f"data={message}, error={error}, "
            f"retry times: {cls._retry}")
        return False

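A minimal report to the LC (worker name and payload hypothetical; LC_SERVER defaults to http://127.0.0.1:9100):

    from neptune.lc_client import LCClient

    ok = LCClient.send("edge-worker-0", {"name": "edge-worker-0",
                                         "kind": "inference",
                                         "results": []})
    print("reported:", ok)
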
lib/requirements.txt (+6 -0)

@@ -0,0 +1,6 @@
flask==1.1.2
keras==2.4.3
opencv-python==4.4.0.44
websockets==8.1
Pillow==8.0.1
requests==2.24.0

lib/setup.py (+19 -0)

@@ -0,0 +1,19 @@
from setuptools import setup

setup(
    name='neptune',
    version='0.0.1',
    description="The neptune package is designed to help developers "
                "better use open source frameworks such as tensorflow "
                "on the Neptune project",
    packages=['neptune'],
    python_requires='>=3.6',
    install_requires=[
        'flask>=1.1.2',
        'keras>=2.4.3',
        'Pillow>=8.0.1',
        'opencv-python>=4.4.0.44',
        'websockets>=8.1',
        'requests>=2.24.0'
    ]
)
