
joint_inference.py 14 kB

import abc
import json
import logging
import os
import threading
import time

import cv2
import numpy as np
import requests
import tensorflow as tf
from PIL import Image
from flask import Flask, request

import neptune
from neptune.common.config import BaseConfig
from neptune.common.constant import K8sResourceKind
from neptune.hard_example_mining import CrossEntropyFilter, IBTFilter, \
    ThresholdFilter
from neptune.joint_inference.data import ServiceInfo
from neptune.lc_client import LCClient

LOG = logging.getLogger(__name__)
class BigModelConfig(BaseConfig):
    def __init__(self):
        BaseConfig.__init__(self)
        self.bind_ip = os.getenv("BIG_MODEL_BIND_IP", "0.0.0.0")
        self.bind_port = int(os.getenv("BIG_MODEL_BIND_PORT", "5000"))


class LittleModelConfig(BaseConfig):
    def __init__(self):
        BaseConfig.__init__(self)


class BigModelClientConfig:
    def __init__(self):
        self.ip = os.getenv("BIG_MODEL_IP")
        self.port = int(os.getenv("BIG_MODEL_PORT", "5000"))
class BaseModel:
    """Model abstract class.

    :param preprocess: function called before inference
    :param postprocess: function called after inference
    :param input_shape: input shape of the model
    :param create_input_feed: the function of creating the input feed
    :param create_output_fetch: the function of creating the output fetch
    """

    def __init__(self, preprocess=None, postprocess=None, input_shape=(0, 0),
                 create_input_feed=None, create_output_fetch=None):
        self.preprocess = preprocess
        self.postprocess = postprocess
        self.input_shape = input_shape
        if create_input_feed is None or create_output_fetch is None:
            raise RuntimeError("Please offer create_input_feed "
                               "and create_output_fetch function")
        self.create_input_feed = create_input_feed
        self.create_output_fetch = create_output_fetch

    @abc.abstractmethod
    def _load_model(self):
        pass

    @abc.abstractmethod
    def inference(self, img_data):
        pass
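
# --- Example (sketch, not part of the original module) ----------------------
# Minimal user callbacks for a BaseModel subclass. The tensor names
# "input_1:0" and "output_1:0" are assumptions; substitute the names your
# own frozen graph actually exposes.
def example_create_input_feed(sess, new_image, *args):
    """Map a preprocessed image batch onto the graph's input tensor."""
    input_tensor = sess.graph.get_tensor_by_name("input_1:0")  # assumed name
    return {input_tensor: new_image}


def example_create_output_fetch(sess):
    """List the tensors that sess.run should fetch."""
    return [sess.graph.get_tensor_by_name("output_1:0")]  # assumed name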
class BigModelClient:
    """Client for the remote big model service running on the cloud."""

    _retry = 5
    _retry_interval_seconds = 1

    def __init__(self):
        self.config = BigModelClientConfig()
        self.big_model_endpoint = "http://{0}:{1}".format(
            self.config.ip, self.config.port
        )

    def _load_model(self):
        pass

    def inference(self, img_data):
        """Send the image to the remote big model server for inference."""
        _, encoded_image = cv2.imencode(".jpeg", img_data)
        files = {"images": encoded_image}
        error = None
        for _ in range(BigModelClient._retry):
            try:
                res = requests.post(self.big_model_endpoint, timeout=5,
                                    files=files)
                if res.status_code < 300:
                    return res.json().get("data")
                LOG.error(f"send request to {self.big_model_endpoint} "
                          f"failed, status is {res.status_code}")
                return None
            except requests.exceptions.RequestException as e:
                error = e
                time.sleep(BigModelClient._retry_interval_seconds)
        LOG.error(f"send request to {self.big_model_endpoint} failed, "
                  f"error is {error}, retry times: {BigModelClient._retry}")
        return None
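
# --- Example (sketch, not part of the original module) ----------------------
# Exercising BigModelClient directly. BIG_MODEL_IP/BIG_MODEL_PORT must point
# at a running TSBigModelService; the address and image path below are
# hypothetical.
#
#   os.environ["BIG_MODEL_IP"] = "192.168.1.10"
#   os.environ["BIG_MODEL_PORT"] = "5000"
#   client = BigModelClient()
#   frame = cv2.imread("sample.jpg")       # BGR numpy array
#   detections = client.inference(frame)   # parsed "data" field, or None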
class TSBigModelService(BaseModel):
    """Big model service implemented with TensorFlow.

    Provides a RESTful interface for big-model inference.
    """

    def __init__(self, preprocess=None, postprocess=None, input_shape=(0, 0),
                 create_input_feed=None, create_output_fetch=None):
        BaseModel.__init__(self, preprocess, postprocess, input_shape,
                           create_input_feed, create_output_fetch)
        self.config = BigModelConfig()
        self.input_shape = input_shape
        self._load_model()
        self.app = Flask(__name__)
        self.register()
        self.app.run(host=self.config.bind_ip, port=self.config.bind_port)

    def register(self):
        @self.app.route('/', methods=['POST'])
        def inference():
            f = request.files.get('images')
            image = Image.open(f)
            image = image.convert("RGB")
            img_data, org_img_shape = self.preprocess(image, self.input_shape)
            data = self.inference(img_data)
            result = self.postprocess(data, org_img_shape)
            # encapsulate the user result
            return json.dumps({"data": result})

    def _load_model(self):
        self.graph = tf.Graph()
        self.sess = tf.compat.v1.InteractiveSession(graph=self.graph)
        with tf.io.gfile.GFile(self.config.model_url, 'rb') as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')
        LOG.info(f"Imported model from {self.config.model_url}.")

    def inference(self, img_data):
        input_feed = self.create_input_feed(self.sess, img_data)
        output_fetch = self.create_output_fetch(self.sess)
        return self.sess.run(output_fetch, input_feed)
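
# --- Example (sketch, not part of the original module) ----------------------
# Starting the cloud side. TSBigModelService.__init__ blocks in app.run(),
# so constructing the object is enough to start serving. The preprocess
# callback must return (img_data, org_img_shape), as the route above
# expects; my_big_model_preprocess and my_big_model_postprocess are
# hypothetical user functions.
#
#   TSBigModelService(
#       preprocess=my_big_model_preprocess,
#       postprocess=my_big_model_postprocess,
#       input_shape=(608, 608),
#       create_input_feed=example_create_input_feed,
#       create_output_fetch=example_create_output_fetch,
#   )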
class TSLittleModel(BaseModel):
    """Little model implemented with TensorFlow.

    Performs inference locally on the edge.
    """

    def __init__(self, preprocess=None, postprocess=None, input_shape=(0, 0),
                 create_input_feed=None, create_output_fetch=None):
        BaseModel.__init__(self, preprocess, postprocess, input_shape,
                           create_input_feed, create_output_fetch)
        self.config = LittleModelConfig()
        graph = tf.Graph()
        config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 0.1
        self.session = tf.compat.v1.Session(graph=graph, config=config)
        self._load_model()

    def _load_model(self):
        with self.session.as_default():
            with self.session.graph.as_default():
                with tf.io.gfile.GFile(self.config.model_url, 'rb') as handle:
                    LOG.info(f"Load model {self.config.model_url}, "
                             f"ParseFromString start.")
                    graph_def = tf.compat.v1.GraphDef()
                    graph_def.ParseFromString(handle.read())
                    tf.import_graph_def(graph_def, name='')
        LOG.info("Imported model from pb file.")

    def inference(self, img_data):
        img_data_np = np.array(img_data)
        with self.session.as_default():
            new_image = self.preprocess(img_data_np, self.input_shape)
            input_feed = self.create_input_feed(self.session, new_image,
                                                img_data_np)
            output_fetch = self.create_output_fetch(self.session)
            return self.session.run(output_fetch, input_feed)
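
# --- Example (sketch, not part of the original module) ----------------------
# A minimal little-model preprocess callback: resize to the model's input
# shape and scale pixels to [0, 1], returning a batched float32 array.
def example_little_model_preprocess(img_data, input_shape):
    """Resize and normalize an image for the little model."""
    resized = cv2.resize(img_data, input_shape)
    normalized = resized.astype(np.float32) / 255.0
    return np.expand_dims(normalized, axis=0)  # add the batch dimension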
class LCReporter(threading.Thread):
    """Thread that periodically reports inference statistics to the LC."""

    def __init__(self):
        threading.Thread.__init__(self)
        # the values of the statistics
        self.inference_number = 0
        self.hard_example_number = 0
        self.period_interval = int(os.getenv("LC_PERIOD", "30"))
        # The system resets period_increment after sending messages to the
        # LC. If period_increment is 0 in the current period, the system
        # does not send messages to the LC.
        self.period_increment = 0
        self.lock = threading.Lock()

    def update_for_edge_inference(self):
        with self.lock:
            self.inference_number += 1
            self.period_increment += 1

    def update_for_collaboration_inference(self):
        with self.lock:
            self.inference_number += 1
            self.hard_example_number += 1
            self.period_increment += 1

    def run(self):
        while True:
            info = ServiceInfo()
            info.startTime = time.strftime("%Y-%m-%d %H:%M:%S",
                                           time.localtime())
            time.sleep(self.period_interval)
            if self.period_increment == 0:
                LOG.debug("period increment is zero, skip report")
                continue
            info.updateTime = time.strftime("%Y-%m-%d %H:%M:%S",
                                            time.localtime())
            info.inferenceNumber = self.inference_number
            info.hardExampleNumber = self.hard_example_number
            info.uploadCloudRatio = (
                self.hard_example_number / self.inference_number
            )
            message = {
                "name": BaseConfig.worker_name,
                "namespace": BaseConfig.namespace,
                "ownerName": BaseConfig.service_name,
                "ownerKind": K8sResourceKind.JOINT_INFERENCE_SERVICE.value,
                "kind": "inference",
                "ownerInfo": info.__dict__,
                "results": []
            }
            LCClient.send(BaseConfig.worker_name, message)
            self.period_increment = 0
class InferenceResult:
    """The result class for joint inference.

    :param is_hard_sample: `True` means a hard sample, `False` means not a
        hard sample
    :param final_result: the final inference result
    :param hard_sample_edge_result: the edge little model inference result of
        a hard sample
    :param hard_sample_cloud_result: the cloud big model inference result of
        a hard sample
    """

    def __init__(self, is_hard_sample, final_result,
                 hard_sample_edge_result, hard_sample_cloud_result):
        self.is_hard_sample = is_hard_sample
        self.final_result = final_result
        self.hard_sample_edge_result = hard_sample_edge_result
        self.hard_sample_cloud_result = hard_sample_cloud_result
class JointInference:
    """Class provided for external systems to run joint model inference.

    :param little_model: the little model entity for edge inference
    :param hard_example_mining_algorithm: the algorithm that judges whether
        a sample is hard
    :param pre_hook: function run before edge inference
    :param post_hook: function run after edge inference
    """

    def __init__(self, little_model: BaseModel,
                 hard_example_mining_algorithm=None,
                 pre_hook=None, post_hook=None):
        self.little_model = little_model
        self.big_model = BigModelClient()
        # TODO: how to process a user-defined cloud_offload_algorithm,
        # especially its parameters
        if hard_example_mining_algorithm is None:
            hem_name = BaseConfig.hem_name
            if hem_name == "IBT":
                threshold_box = float(neptune.context.get_hem_parameters(
                    "threshold_box", 0.5
                ))
                threshold_img = float(neptune.context.get_hem_parameters(
                    "threshold_img", 0.5
                ))
                hard_example_mining_algorithm = IBTFilter(threshold_img,
                                                          threshold_box)
            elif hem_name == "CrossEntropy":
                threshold_cross_entropy = float(
                    neptune.context.get_hem_parameters(
                        "threshold_cross_entropy", 0.5
                    )
                )
                hard_example_mining_algorithm = CrossEntropyFilter(
                    threshold_cross_entropy)
            else:
                hard_example_mining_algorithm = ThresholdFilter()
        self.cloud_offload_algorithm = hard_example_mining_algorithm
        self.pre_hook = pre_hook
        self.post_hook = post_hook
        self.lc_reporter = LCReporter()
        self.lc_reporter.daemon = True
        self.lc_reporter.start()

    def inference(self, img_data) -> InferenceResult:
        """Image inference function."""
        img_data_pre = img_data
        if self.pre_hook:
            img_data_pre = self.pre_hook(img_data_pre)
        edge_result = self.little_model.inference(img_data_pre)
        if self.post_hook:
            edge_result = self.post_hook(edge_result)
        is_hard_sample = self.cloud_offload_algorithm.hard_judge(edge_result)
        if not is_hard_sample:
            LOG.debug("not hard sample, use edge result directly")
            self.lc_reporter.update_for_edge_inference()
            return InferenceResult(False, edge_result, None, None)
        cloud_result = self._cloud_inference(img_data)
        if cloud_result is None:
            LOG.warning("retrieve cloud infer service failed, use edge result")
            self.lc_reporter.update_for_edge_inference()
            return InferenceResult(True, edge_result, edge_result, None)
        LOG.debug(f"retrieve cloud infer service success, use cloud "
                  f"result, cloud result: {cloud_result}")
        self.lc_reporter.update_for_collaboration_inference()
        return InferenceResult(True, cloud_result, edge_result, cloud_result)

    def _cloud_inference(self, img_rgb):
        return self.big_model.inference(img_rgb)
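
# --- Example (sketch, not part of the original module) ----------------------
# Wiring an edge worker together. MODEL_URL and the HEM settings come from
# BaseConfig / environment variables in a real deployment; my_postprocess is
# a hypothetical user function.
#
#   little_model = TSLittleModel(
#       preprocess=example_little_model_preprocess,
#       postprocess=my_postprocess,
#       input_shape=(416, 416),
#       create_input_feed=example_create_input_feed,
#       create_output_fetch=example_create_output_fetch,
#   )
#   inferer = JointInference(little_model)
#   result = inferer.inference(frame)      # frame: numpy image array
#   if result.is_hard_sample:
#       ...  # final_result came from the cloud when it was reachable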
def _get_or_default(parameter, default):
    """Read a context parameter, falling back to the default when unset."""
    value = neptune.context.get_parameters(parameter)
    return value if value else default
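
# --- Example (sketch, not part of the original module) ----------------------
# Typical use of _get_or_default: read a tunable from the neptune context,
# falling back when the parameter is unset ("score_threshold" is a
# hypothetical parameter name).
#
#   score_threshold = float(_get_or_default("score_threshold", "0.3"))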