You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

worker.py 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. # Copyright 2021 The KubeEdge Authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from distutils import util
  15. import pathlib
  16. import time
  17. import torch
  18. import numpy as np
  19. from PIL import Image
  20. from threading import Thread
  21. from typing import List
  22. from store_result import save_image
  23. from sedna.algorithms.reid.multi_img_matching import match_query_to_targets
  24. from sedna.algorithms.reid.close_contact_estimation import ContactTracker
  25. from sedna.common.log import LOGGER
  26. from sedna.core.multi_edge_inference.components.reid import ReID
  27. from sedna.core.multi_edge_inference.data_classes import DetTrackResult, OP_MODE, Target
  28. from sedna.core.multi_edge_inference.utils import get_parameters
  29. from sedna.datasources.obs.connector import OBSClientWrapper
  # Root of the shared volume: query images, input .dat files and stored results live below it.
  MOUNT_PATH = "/data/network_shared/reid"
  class ReIDWorker:
      """ReID model plugin: matches per-frame detections against query targets.

      Configuration is read through ``get_parameters`` at construction time:
      ``op_mode``, ``match_threshold``, ``user_id``, ``query_images``
      (a ``|``-separated list of image paths), ``ENABLE_OBS`` and ``OBS_TOKEN``.
      """

      def __init__(self, **kwargs):
          # Service parameters
          self.op_mode = OP_MODE(get_parameters('op_mode', 'covid19'))
          # get_parameters may return a string (e.g. a value supplied externally);
          # coerce so tracking_no_gallery compares float against float.
          self.threshold = float(get_parameters('match_threshold', 0.75))
          self.user_id = get_parameters('user_id', "DEFAULT")
          self.query_images = str(
              get_parameters('query_images', f"{MOUNT_PATH}/query/sample.png")
          ).split("|")
          self.target = None  # not read anywhere in this class
          self.targets_list: List[Target] = []
          self.results_base_folder = f"{MOUNT_PATH}/images/"
          self.CT = ContactTracker(draw_top_view=False)
          # str() so strtobool (which lower-cases its input) also accepts non-string values.
          self.enable_obs = bool(util.strtobool(str(get_parameters('ENABLE_OBS', "False"))))
          if self.enable_obs:
              self.obs_client = OBSClientWrapper(app_token=get_parameters('OBS_TOKEN', ''))
          super().__init__()

      def update_plugin(self, status):
          """Build the target query from the configured query images.

          ``status`` is not used.

          Returns a one-element list holding a DetTrackResult with the loaded
          images and ``userID`` set to the configured user id, or None when the
          operational mode is DETECTION.
          """
          if self.op_mode == OP_MODE.DETECTION:
              return None
          LOGGER.info("Loading target query images")
          # The target collection is a list of targets/userid that might grow overtime
          img_arr = []
          for image_path in self.query_images:
              if not image_path:
                  # Tolerate stray '|' separators in the query_images parameter.
                  continue
              # Close the file handle once the pixels have been copied into the array.
              with Image.open(image_path) as img:
                  img_arr.append(np.asarray(img))
          data = DetTrackResult(0, img_arr, None, [], 0, 0)
          data.userID = self.user_id
          return [data]

      def update_target(self, ldata):
          """
          Updates the target for the ReID.

          Replaces the list of targets iterated by tracking_no_gallery. ``None``
          is stored as an empty list.
          """
          if ldata:
              LOGGER.info(f"Target updated for user {ldata[0].userid} with {len(ldata[0].features)} feature vectors!")
          else:
              LOGGER.warning("Target list updated with no targets!")
          self.targets_list = ldata if ldata is not None else []

      def reid_per_frame(self, features, candidate_feats: torch.Tensor) -> tuple:
          """
          Match one target's features against the detections of a frame.

          The similarity between the query ``features`` and the
          ``candidate_feats`` (from the detected boxes) is computed by
          ``match_query_to_targets``; no thresholding is done here.
          @param features: feature vectors of the target, or None if not set.
          @param candidate_feats: ReID features of the boxes detected in the frame.
          @return: (match_id, match_score) where match_id indexes the best-matching
              detection, or (-1, None) when ``features`` is None.
          """
          # `is None`, not `== None`: array-like features would otherwise be
          # compared element-wise. Always return a pair, since callers unpack two values.
          if features is None:
              LOGGER.warning("Target has not been set!")
              return -1, None
          match_id, match_score = match_query_to_targets(features, candidate_feats, False)
          return match_id, match_score

      def predict(self, data, **kwargs):
          """Implements the on-the-fly ReID where detections per frame are matched with the candidate boxes.

          Each item of ``data`` is passed to the ``<op_mode>_no_gallery`` method;
          every non-None result is stored (see store_result) and returned in a list.
          """
          tresult = []
          # Resolve the op-mode handler once; a missing handler affects every item.
          try:
              handler = getattr(self, self.op_mode.value + "_no_gallery")
          except AttributeError as ex:
              LOGGER.error(f"Error in dynamic function mapping. [{ex}]")
              return tresult
          for dettrack_obj in data:
              try:
                  reid_result = handler(dettrack_obj)
              except AttributeError as ex:
                  # One malformed item should not drop the rest of the batch.
                  LOGGER.error(f"Unable to process ReID input. [{ex}]")
                  continue
              if reid_result is not None:
                  tresult.append(reid_result)
                  self.store_result(reid_result)
          return tresult

      ### OP_MODE FUNCTIONS ###
      def covid19_no_gallery(self, det_track):
          """COVID-19 mode: same processing as tracking_no_gallery."""
          return self.tracking_no_gallery(det_track)

      def detection_no_gallery(self, det_track):
          """Detection mode is unsupported without a gallery; always returns None."""
          LOGGER.warning(f"This operational mode ({self.op_mode}) is not supported without gallery.")
          return None

      def tracking_no_gallery(self, det_track: DetTrackResult):
          """
          Performs ReID without gallery using the results from the
          tracking and feature extraction component.

          Sets ``det_track.targetID`` to one entry per bounding box: the userid of
          a target whose best match scored >= threshold on that box, else -1.
          Returns ``det_track`` if at least one box matched, otherwise None.
          """
          det_track.targetID = [-1] * len(det_track.bbox_coord)
          for target in self.targets_list:
              # get id of highest match for each userid
              match_id, match_score = self.reid_per_frame(target.features, det_track.features)
              if match_id < 0:
                  # Target has no features yet, so there is nothing to compare.
                  continue
              result = {
                  "userID": str(target.userid),
                  "match_id": str(match_id),
                  "match_score": str(match_score),
                  "detection_area": det_track.camera,
                  "detection_time": det_track.detection_time,
              }
              if float(match_score) >= self.threshold:
                  det_track.targetID[match_id] = str(target.userid)
                  det_track.userID = target.userid
                  det_track.is_target = match_id
              LOGGER.info(result)
          if all(tid == -1 for tid in det_track.targetID):
              # No target found, we don't send any result back
              return None
          return det_track

      def store_result(self, det_track: DetTrackResult):
          """
          Stores ReID result on disk (and OBS, if enabled).

          Files go to ``<results_base_folder><userID>/``. Any failure is logged,
          not raised.
          """
          try:
              folder = f"{self.results_base_folder}{det_track.userID}/"
              filename = save_image(det_track, self.CT, folder=folder)
              if self.enable_obs:
                  self.obs_client.upload_file(folder, filename, f"/media/reid/{det_track.userID}")
          except Exception as ex:
              LOGGER.error(f"Unable to save image: {ex}")
  class Bootstrapper(Thread):
      """Daemon thread that feeds ``.dat`` files found under MOUNT_PATH into a ReID job.

      When no files are listed it retries up to ``self.retry`` times, sleeping
      5 seconds between attempts. Once a listing is non-empty, every ``.dat``
      file with readable data is queued and then deleted, and the loop stops.
      """

      def __init__(self):
          super().__init__()
          self.daemon = True
          self.retry = 3
          self.job = ReID(models=[ReIDWorker()], asynchronous=False)

      def run(self) -> None:
          LOGGER.info("Loading data from disk.")
          while self.retry > 0:
              files = self.job.get_files_list(f"{MOUNT_PATH}/")
              if not files:
                  LOGGER.warning("No data available to process!")
                  self.retry -= 1
                  time.sleep(5)
                  continue

              LOGGER.debug(f"Loaded {len(files)} files.")
              dat_files = (name for name in files if pathlib.Path(name).suffix == '.dat')
              for filename in dat_files:
                  data = self.job.read_from_disk(filename)
                  if not data:
                      continue
                  LOGGER.debug(f"File {filename} loaded!")
                  self.job.put(data)
                  self.job.delete_from_disk(filename)
              # A non-empty listing is processed exactly once.
              break
          LOGGER.info("ReID job completed.")
  # Start the ReID job (run() is called directly, so it executes in the main thread).
  if __name__ == '__main__':
      Bootstrapper().run()