You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

predict.py 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. import os
  2. import cv2
  3. import time
  4. import numpy as np
  5. from PIL import Image
  6. import warnings
  7. from sedna.datasources import BaseDataSource, TxtDataParse
  8. from basemodel import Model, preprocess_frames
  9. from sedna.core.lifelong_learning import LifelongLearning
  10. from sedna.common.config import Context
  11. def preprocess(samples):
  12. data = BaseDataSource(data_type="test")
  13. data.x = [samples]
  14. return data
  15. def postprocess(samples):
  16. image_names, imgs = [], []
  17. for sample in samples:
  18. img = sample.get("image")
  19. image_names.append("{}.png".format(str(time.time())))
  20. imgs.append(img)
  21. return image_names, imgs
  22. def _load_txt_dataset(dataset_url):
  23. # use original dataset url
  24. original_dataset_url = Context.get_parameters('original_dataset_url', "")
  25. dataset_urls = dataset_url.split()
  26. dataset_urls = [
  27. os.path.join(
  28. os.path.dirname(original_dataset_url),
  29. dataset_url) for dataset_url in dataset_urls]
  30. return dataset_urls[:-1], dataset_urls[-1]
  31. def init_ll_job(**kwargs):
  32. estimator = Model(num_class=31,
  33. weight_path=kwargs.get('weight_path'),
  34. save_predicted_image=True,
  35. merge=True)
  36. task_allocation = {
  37. "method": "TaskAllocationDefault"
  38. }
  39. unseen_task_allocation = {
  40. "method": "UnseenTaskAllocationDefault"
  41. }
  42. unseen_sample_recognition = {
  43. "method": "OodIdentification",
  44. "param": {
  45. "OOD_thresh": float(kwargs.get("OOD_thresh")),
  46. "backup_model": kwargs.get("OOD_backup_model"),
  47. "OOD_model_path": kwargs.get("OOD_model"),
  48. "preprocess_func": preprocess_frames,
  49. "base_model": Model
  50. }
  51. }
  52. # unseen_sample_recognition = {
  53. # "method": "SampleRegonitionRobotic"
  54. # }
  55. inference_integrate = {
  56. "method": "InferenceIntegrateByType"
  57. }
  58. ll_job = LifelongLearning(
  59. estimator,
  60. unseen_estimator=unseen_task_processing,
  61. task_definition=None,
  62. task_relationship_discovery=None,
  63. task_allocation=task_allocation,
  64. task_remodeling=None,
  65. inference_integrate=inference_integrate,
  66. task_update_decision=None,
  67. unseen_task_allocation=unseen_task_allocation,
  68. unseen_sample_recognition=unseen_sample_recognition,
  69. unseen_sample_re_recognition=None)
  70. return ll_job
  71. def unseen_task_processing():
  72. return "Warning: unseen sample detected."
  73. def predict():
  74. ll_job = init_ll_job()
  75. camera_address = Context.get_parameters('video_url')
  76. # use video streams for testing
  77. camera = cv2.VideoCapture(camera_address)
  78. fps = 10
  79. nframe = 0
  80. while 1:
  81. ret, input_yuv = camera.read()
  82. if not ret:
  83. time.sleep(5)
  84. camera = cv2.VideoCapture(camera_address)
  85. continue
  86. if nframe % fps:
  87. nframe += 1
  88. continue
  89. img_rgb = cv2.cvtColor(input_yuv, cv2.COLOR_BGR2RGB)
  90. nframe += 1
  91. if nframe % 1000 == 1: # logs every 1000 frames
  92. warnings.warn(f"camera is open, current frame index is {nframe}")
  93. img_rgb = cv2.resize(np.array(img_rgb), (2048, 1024),
  94. interpolation=cv2.INTER_CUBIC)
  95. img_rgb = Image.fromarray(img_rgb)
  96. data = {'image': img_rgb, "depth": img_rgb, "label": img_rgb}
  97. data = preprocess(data)
  98. print(postprocess)
  99. print("Inference results:", ll_job.inference(
  100. data=data, post_process=postprocess))
  101. def predict_batch():
  102. ll_job = init_ll_job()
  103. test_dataset_url = Context.get_parameters("test_dataset_url")
  104. test_data = TxtDataParse(data_type="test", func=_load_txt_dataset)
  105. test_data.parse(test_dataset_url, use_raw=False)
  106. return ll_job.inference(data=test_data)
  107. if __name__ == '__main__':
  108. print(predict())