
joint_inference.py

import abc
import json
import logging
import os
import threading
import time

import cv2
import numpy as np
import requests
import tensorflow as tf
from PIL import Image
from flask import Flask, request

import neptune
from neptune.common.config import BaseConfig
from neptune.common.constant import K8sResourceKind
from neptune.hard_example_mining import CrossEntropyFilter, IBTFilter, \
    ThresholdFilter
from neptune.joint_inference.data import ServiceInfo
from neptune.lc_client import LCClient

LOG = logging.getLogger(__name__)


class BigModelConfig(BaseConfig):
    """Configuration of the cloud big-model server."""

    def __init__(self):
        BaseConfig.__init__(self)
        self.bind_ip = os.getenv("BIG_MODEL_BIND_IP", "0.0.0.0")
        self.bind_port = int(os.getenv("BIG_MODEL_BIND_PORT", "5000"))


class LittleModelConfig(BaseConfig):
    """Configuration of the edge little model."""

    def __init__(self):
        BaseConfig.__init__(self)


class BigModelClientConfig:
    """Configuration of the client that calls the remote big-model server."""

    def __init__(self):
        self.ip = os.getenv("BIG_MODEL_IP")
        self.port = int(os.getenv("BIG_MODEL_PORT", "5000"))
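
# The config classes above are driven entirely by environment variables;
# a hypothetical deployment might set:
#
#   BIG_MODEL_BIND_IP=0.0.0.0  BIG_MODEL_BIND_PORT=5000   # cloud worker
#   BIG_MODEL_IP=<cloud address>  BIG_MODEL_PORT=5000     # edge worker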


class BaseModel(abc.ABC):
    """Model abstract class.

    :param preprocess: function executed before inference
    :param postprocess: function executed after inference
    :param input_shape: the input shape of the model
    :param create_input_feed: the function of creating input feed
    :param create_output_fetch: the function of creating output fetch
    """

    def __init__(self, preprocess=None, postprocess=None, input_shape=(0, 0),
                 create_input_feed=None, create_output_fetch=None):
        self.preprocess = preprocess
        self.postprocess = postprocess
        self.input_shape = input_shape
        if create_input_feed is None or create_output_fetch is None:
            raise RuntimeError("Please provide create_input_feed "
                               "and create_output_fetch functions")
        self.create_input_feed = create_input_feed
        self.create_output_fetch = create_output_fetch

    @abc.abstractmethod
    def _load_model(self):
        pass

    @abc.abstractmethod
    def inference(self, img_data):
        pass
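
# A minimal sketch of the two required callables. The tensor names
# ("input_1:0", "output_1:0") are hypothetical and depend on the frozen
# graph actually loaded:
#
#   def create_input_feed(sess, new_image, img_data=None):
#       input_tensor = sess.graph.get_tensor_by_name("input_1:0")
#       return {input_tensor: new_image}
#
#   def create_output_fetch(sess):
#       return [sess.graph.get_tensor_by_name("output_1:0")]
#
# Note that TSBigModelService calls create_input_feed(sess, img_data) with
# two arguments, while TSLittleModel also passes the raw frame as a third.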


class BigModelClient:
    """Client of the remote big model service in the cloud."""

    _retry = 5
    _retry_interval_seconds = 1

    def __init__(self):
        self.config = BigModelClientConfig()
        self.big_model_endpoint = "http://{0}:{1}".format(
            self.config.ip, self.config.port
        )

    def _load_model(self):
        pass

    def inference(self, img_data):
        """Use the remote big model server for inference."""
        _, encoded_image = cv2.imencode(".jpeg", img_data)
        files = {"images": encoded_image}
        error = None
        # retry only on transport errors; an HTTP error status fails fast
        for _ in range(BigModelClient._retry):
            try:
                res = requests.post(self.big_model_endpoint, timeout=5,
                                    files=files)
                if res.status_code < 300:
                    return res.json().get("data")
                LOG.error(f"request to {self.big_model_endpoint} failed, "
                          f"status is {res.status_code}")
                return None
            except requests.exceptions.RequestException as e:
                error = e
                time.sleep(BigModelClient._retry_interval_seconds)
        LOG.error(f"request to {self.big_model_endpoint} failed, "
                  f"error is {error}, retry times: {BigModelClient._retry}")
        return None
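
# Wire format between BigModelClient and TSBigModelService: a multipart POST
# with the JPEG-encoded frame in the "images" field, answered with JSON of
# the form {"data": <postprocessed result>}.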


class TSBigModelService(BaseModel):
    """Large model service implemented based on TensorFlow.

    Provides a RESTful interface for large-model inference.
    """

    def __init__(self, preprocess=None, postprocess=None, input_shape=(0, 0),
                 create_input_feed=None, create_output_fetch=None):
        BaseModel.__init__(self, preprocess, postprocess, input_shape,
                           create_input_feed, create_output_fetch)
        self.config = BigModelConfig()
        self._load_model()
        self.app = Flask(__name__)
        self.register()
        # blocks here, serving inference requests until the process exits
        self.app.run(host=self.config.bind_ip,
                     port=self.config.bind_port)

    def register(self):
        @self.app.route('/', methods=['POST'])
        def inference():
            f = request.files.get('images')
            image = Image.open(f)
            image = image.convert("RGB")
            img_data, org_img_shape = self.preprocess(image, self.input_shape)
            data = self.inference(img_data)
            result = self.postprocess(data, org_img_shape)
            # encapsulate the user result
            return json.dumps({"data": result})

    def _load_model(self):
        self.graph = tf.Graph()
        # InteractiveSession installs itself and its graph as the defaults,
        # so import_graph_def below imports into self.graph
        self.sess = tf.compat.v1.InteractiveSession(graph=self.graph)
        with tf.io.gfile.GFile(self.config.model_url, 'rb') as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')
        LOG.info(f"Import yolo model from {self.config.model_url} end .......")

    def inference(self, img_data):
        input_feed = self.create_input_feed(self.sess, img_data)
        output_fetch = self.create_output_fetch(self.sess)
        return self.sess.run(output_fetch, input_feed)
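
# Cloud-side entry sketch: constructing the service loads the model,
# registers the route, and blocks in app.run(). The callables are
# user-supplied (hypothetical names; input_shape for illustration only):
#
#   TSBigModelService(preprocess=big_model_preprocess,
#                     postprocess=big_model_postprocess,
#                     input_shape=(416, 416),
#                     create_input_feed=create_input_feed,
#                     create_output_fetch=create_output_fetch)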


class TSLittleModel(BaseModel):
    """Little model entity implemented based on TensorFlow.

    Performs local inference on the edge.
    """

    def __init__(self, preprocess=None, postprocess=None, input_shape=(0, 0),
                 create_input_feed=None, create_output_fetch=None):
        BaseModel.__init__(self, preprocess, postprocess, input_shape,
                           create_input_feed, create_output_fetch)
        self.config = LittleModelConfig()
        graph = tf.Graph()
        # limit GPU usage so the little model can share the edge device
        sess_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True
        sess_config.gpu_options.per_process_gpu_memory_fraction = 0.1
        self.session = tf.compat.v1.Session(graph=graph, config=sess_config)
        self._load_model()

    def _load_model(self):
        with self.session.as_default():
            with self.session.graph.as_default():
                with tf.io.gfile.GFile(self.config.model_url, 'rb') as handle:
                    LOG.info(f"Load model {self.config.model_url}, "
                             f"ParseFromString start .......")
                    graph_def = tf.compat.v1.GraphDef()
                    graph_def.ParseFromString(handle.read())
                    LOG.info("ParseFromString end .......")
                    tf.import_graph_def(graph_def, name='')
                    LOG.info("Import_graph_def end .......")
        LOG.info("Import model from pb end .......")

    def inference(self, img_data):
        img_data_np = np.array(img_data)
        with self.session.as_default():
            new_image = self.preprocess(img_data_np, self.input_shape)
            input_feed = self.create_input_feed(self.session, new_image,
                                                img_data_np)
            output_fetch = self.create_output_fetch(self.session)
            output = self.session.run(output_fetch, input_feed)
            if self.postprocess:
                output = self.postprocess(output)
        return output


class LCReporter(threading.Thread):
    """Thread that periodically reports inference statistics to the LC."""

    def __init__(self):
        threading.Thread.__init__(self)
        # the values of the statistics
        self.inference_number = 0
        self.hard_example_number = 0
        self.period_interval = int(os.getenv("LC_PERIOD", "30"))
        # The system resets period_increment after sending the messages to
        # the LC. If period_increment is 0 in the current period,
        # the system does not send the messages to the LC.
        self.period_increment = 0
        self.lock = threading.Lock()

    def update_for_edge_inference(self):
        with self.lock:
            self.inference_number += 1
            self.period_increment += 1

    def update_for_collaboration_inference(self):
        with self.lock:
            self.inference_number += 1
            self.hard_example_number += 1
            self.period_increment += 1

    def run(self):
        while True:
            info = ServiceInfo()
            info.startTime = time.strftime("%Y-%m-%d %H:%M:%S",
                                           time.localtime())
            time.sleep(self.period_interval)
            # snapshot and reset the counters under the lock, so the report
            # stays consistent with the update_* methods
            with self.lock:
                if self.period_increment == 0:
                    LOG.debug("period increment is zero, skip report")
                    continue
                info.inferenceNumber = self.inference_number
                info.hardExampleNumber = self.hard_example_number
                info.uploadCloudRatio = (
                    self.hard_example_number / self.inference_number
                )
                self.period_increment = 0
            info.updateTime = time.strftime("%Y-%m-%d %H:%M:%S",
                                            time.localtime())
            message = {
                "name": BaseConfig.worker_name,
                "namespace": BaseConfig.namespace,
                "ownerName": BaseConfig.service_name,
                "ownerKind": K8sResourceKind.JOINT_INFERENCE_SERVICE.value,
                "kind": "inference",
                "ownerInfo": info.__dict__,
                "results": []
            }
            LCClient.send(BaseConfig.worker_name, message)


class InferenceResult:
    """The result class for joint inference.

    :param is_hard_example: `True` means a hard example, `False` means not a
        hard example
    :param final_result: the final inference result
    :param hard_example_edge_result: the edge little model inference result of
        a hard example
    :param hard_example_cloud_result: the cloud big model inference result of
        a hard example
    """

    def __init__(self, is_hard_example, final_result,
                 hard_example_edge_result, hard_example_cloud_result):
        self.is_hard_example = is_hard_example
        self.final_result = final_result
        self.hard_example_edge_result = hard_example_edge_result
        self.hard_example_cloud_result = hard_example_cloud_result


class JointInference:
    """Class provided for external systems for model joint inference.

    :param little_model: the little model entity for edge inference
    :param hard_example_mining_algorithm: the algorithm for judging hard
        examples
    """

    def __init__(self, little_model: BaseModel,
                 hard_example_mining_algorithm=None):
        self.little_model = little_model
        self.big_model = BigModelClient()
        # TODO: how to handle a user-defined cloud_offload_algorithm,
        # especially its parameters
        if hard_example_mining_algorithm is None:
            hem_name = BaseConfig.hem_name
            if hem_name == "IBT":
                threshold_box = float(neptune.context.get_hem_parameters(
                    "threshold_box", 0.5
                ))
                threshold_img = float(neptune.context.get_hem_parameters(
                    "threshold_img", 0.5
                ))
                hard_example_mining_algorithm = IBTFilter(threshold_img,
                                                          threshold_box)
            elif hem_name == "CrossEntropy":
                threshold_cross_entropy = float(
                    neptune.context.get_hem_parameters(
                        "threshold_cross_entropy", 0.5
                    )
                )
                hard_example_mining_algorithm = CrossEntropyFilter(
                    threshold_cross_entropy)
            else:
                hard_example_mining_algorithm = ThresholdFilter()
        self.hard_example_mining_algorithm = hard_example_mining_algorithm
        self.lc_reporter = LCReporter()
        self.lc_reporter.daemon = True
        self.lc_reporter.start()

    def inference(self, img_data) -> InferenceResult:
        """Image inference function."""
        edge_result = self.little_model.inference(img_data)
        is_hard_example = self.hard_example_mining_algorithm.hard_judge(
            edge_result
        )
        if not is_hard_example:
            LOG.debug("not a hard example, use the edge result directly")
            self.lc_reporter.update_for_edge_inference()
            return InferenceResult(False, edge_result, None, None)

        cloud_result = self._cloud_inference(img_data)
        if cloud_result is None:
            LOG.warning("cloud inference service failed, use edge result")
            self.lc_reporter.update_for_edge_inference()
            return InferenceResult(True, edge_result, edge_result, None)

        LOG.debug(f"cloud inference service succeeded, use cloud result: "
                  f"{cloud_result}")
        self.lc_reporter.update_for_collaboration_inference()
        return InferenceResult(True, cloud_result, edge_result, cloud_result)

    def _cloud_inference(self, img_rgb):
        return self.big_model.inference(img_rgb)
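
# Edge-side usage sketch (same hypothetical callables as above; `frame` is
# an RGB image such as a decoded video frame):
#
#   little_model = TSLittleModel(preprocess=little_model_preprocess,
#                                postprocess=little_model_postprocess,
#                                input_shape=(416, 416),
#                                create_input_feed=create_input_feed,
#                                create_output_fetch=create_output_fetch)
#   inference_instance = JointInference(little_model)
#   result = inference_instance.inference(frame)
#   if result.is_hard_example:
#       LOG.info(f"hard example, cloud result: "
#                f"{result.hard_example_cloud_result}")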


def _get_or_default(parameter, default):
    """Get a parameter from the context, falling back to `default`."""
    value = neptune.context.get_parameters(parameter)
    return value if value else default