You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

worker.py 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # Copyright 2021 The KubeEdge Authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import datetime
  15. from distutils import util
  16. import os
  17. import time
  18. import cv2
  19. from urllib.request import Request, urlopen
  20. from estimator import str_to_estimator_class
  21. from sedna.common.log import LOGGER
  22. from sedna.core.multi_edge_tracking.utils import get_parameters
  23. from sedna.datasources.obs.connector import OBSClientWrapper
  24. from sedna.core.multi_edge_tracking.components.detector import ObjectDetector
  25. class Bootstrapper():
  26. def __init__(self) -> None:
  27. LOGGER.info("Creating Detection/Tracking Bootstrapper module")
  28. self.estimator_class = get_parameters('estimator_class', "ByteTracker")
  29. self.hostname = get_parameters('hostname', "unknown")
  30. self.fps = float(get_parameters('fps', 25))
  31. self.batch_size = int(get_parameters('batch_size', 1))
  32. self.video_id = get_parameters('video_id', 0)
  33. self.video_address = get_parameters('video_address', "")
  34. self.eclass = str_to_estimator_class(estimator_class=self.estimator_class)
  35. self.enable_obs = bool(util.strtobool(get_parameters('ENABLE_OBS', "False")))
  36. if self.enable_obs:
  37. self.obs_client = OBSClientWrapper(app_token=get_parameters('OBS_TOKEN', ''))
  38. self.service = None
  39. def run(self):
  40. protocol = self.video_address.split(":")[0]
  41. LOGGER.info(f"Detected video source protocol {protocol} for video source {self.video_address}.")
  42. # TODO: This is not reliable. For example, HLS won't work (https://en.wikipedia.org/wiki/HTTP_Live_Streaming).
  43. if protocol in ["rtmp", "rtsp"]: #stream
  44. self.process_video_from_stream()
  45. elif protocol in ["http"]: #cdn
  46. filename = self.download_video(protocol)
  47. self.process_video_from_disk(filename)
  48. elif os.path.isfile(self.video_address): #file from disk (preloaded)
  49. self.process_video_from_disk(self.video_address)
  50. else: # file from obs?
  51. filename = self.obs_client.download_single_object(self.video_address, ".")
  52. if filename:
  53. self.process_video_from_disk(filename)
  54. else:
  55. LOGGER.error(f"Unable to open {self.video_address}.")
  56. self.close()
  57. def download_video(self):
  58. try:
  59. req = Request(self.video_address, headers={'User-Agent': 'Mozilla/5.0'})
  60. LOGGER.info("Video download complete")
  61. filename = f'{self.video_id}.mp4'
  62. with open(filename,'wb') as f:
  63. f.write(urlopen(req).read())
  64. return filename
  65. except Exception as ex:
  66. LOGGER.error(f"Unable to download video file {ex}")
  67. def connect_to_camera(self, stream_address):
  68. camera = None
  69. while camera == None or not camera.isOpened():
  70. try:
  71. camera = cv2.VideoCapture(stream_address)
  72. camera.set(cv2.CAP_PROP_BUFFERSIZE, 0)
  73. except Exception as ex:
  74. LOGGER.error(f'Unable to open video source: [{ex}]')
  75. time.sleep(1)
  76. return camera
  77. def process_video_from_disk(self, filename, timeout=20):
  78. selected_estimator=self.eclass(video_id=self.video_id)
  79. self.service = ObjectDetector(models=[selected_estimator])
  80. cap = cv2.VideoCapture(filename)
  81. index = 0
  82. while(cap.isOpened()):
  83. # Capture frame-by-frame
  84. ret, frame = cap.read()
  85. if ret == True:
  86. LOGGER.debug(f"Current frame index is {index}.")
  87. img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
  88. det_time = datetime.datetime.now().strftime("%a, %d %B %Y %H:%M:%S.%f")
  89. self.service.put((img_rgb, det_time, index))
  90. index +=1
  91. else:
  92. break
  93. # When everything done, release the video capture object
  94. cap.release()
  95. def process_video_from_stream(self, timeout=20):
  96. selected_estimator=self.eclass(video_id=self.video_id)
  97. self.service=ObjectDetector(models=[selected_estimator], asynchronous=True)
  98. nframe = 0
  99. grabbed = False
  100. last_snapshot = time.time()
  101. camera = self.connect_to_camera(self.video_address)
  102. while (camera.isOpened()):
  103. grabbed = camera.grab()
  104. if grabbed:
  105. if ((time.time() - last_snapshot) >= 1/self.fps):
  106. LOGGER.debug(f"Current frame index is {nframe}.")
  107. ret, frame = camera.retrieve()
  108. if ret:
  109. cv2.cvtColor(src=frame, code=cv2.COLOR_BGR2RGB, dst=frame)
  110. det_time = datetime.datetime.now().strftime("%a, %d %B %Y %H:%M:%S.%f")
  111. self.service.put(data=(frame, det_time, nframe))
  112. last_snapshot = time.time()
  113. nframe += 1
  114. elif (time.time() - last_snapshot) >= timeout:
  115. LOGGER.debug(f"Timeout reached, releasing video source.")
  116. camera.release()
  117. def close(self, timeout=20):
  118. while (time.time() - self.service.heartbeat) <= timeout:
  119. LOGGER.debug(f"Waiting for more data from the feature extraction service..")
  120. time.sleep(1)
  121. #perform cleanup of the service
  122. self.service.close()
  123. LOGGER.info(f"VideoAnalysis job completed.")
  124. if __name__ == '__main__':
  125. bs = Bootstrapper()
  126. bs.run()