You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

postprocess.py 2.7 kB

3 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """YoloV5 310 infer."""
  16. import os
  17. import time
  18. import numpy as np
  19. from pycocotools.coco import COCO
  20. from src.logger import get_logger
  21. from src.util import DetectionEngine
  22. from model_utils.config import config
  # Post-processing for YoloV5 310 inference outputs: read the raw head tensors, run detection/NMS, evaluate on COCO.
# Spatial size of each detection head, in the order the heads are saved
# (``<stem>_0.bin``, ``<stem>_1.bin``, ``<stem>_2.bin``).
# NOTE(review): 20/40/80 with 3 anchors and 85 values per anchor (presumably
# 4 box + 1 objectness + 80 COCO classes) look like a 640x640 input model;
# confirm against the exported model before reusing this script elsewhere.
_HEAD_GRID_SIZES = (20, 40, 80)
_NUM_ANCHORS = 3
_VALUES_PER_ANCHOR = 85


def _load_head_outputs(result_dir, stem):
    """Load the three raw head tensors saved for one image.

    Args:
        result_dir: directory holding the ``.bin`` files.
        stem: image file name without extension; the tensors are read from
            ``<result_dir>/<stem>_<k>.bin`` for k = 0, 1, 2.

    Returns:
        A list of three float32 arrays shaped
        ``(1, grid, grid, _NUM_ANCHORS, _VALUES_PER_ANCHOR)``, one per entry
        of ``_HEAD_GRID_SIZES`` and in the same order.
    """
    outputs = []
    for head_idx, grid in enumerate(_HEAD_GRID_SIZES):
        bin_path = os.path.join(result_dir, '{}_{}.bin'.format(stem, head_idx))
        raw = np.fromfile(bin_path, dtype=np.float32)
        outputs.append(raw.reshape(1, grid, grid, _NUM_ANCHORS, _VALUES_PER_ANCHOR))
    return outputs


def main():
    """Evaluate saved inference outputs against the COCO annotations.

    Every file in ``config.dataset_path`` whose name (up to the first '.') is
    an integer COCO image id is processed. Its head outputs are read from
    ``config.result_files`` and fed to ``DetectionEngine.detect``. NMS,
    result-file writing and mAP evaluation then run once over all images,
    and the results and elapsed time are logged.
    """
    start_time = time.time()
    config.output_dir = config.log_path
    config.logger = get_logger(config.output_dir, 0)

    # init detection engine
    detection = DetectionEngine(config, config.test_ignore_threshold)
    coco = COCO(config.ann_file)
    result_path = config.result_files

    # Sorted so the processing order (and hence the written result file) is
    # deterministic; os.listdir returns entries in arbitrary order.
    for file_name in sorted(os.listdir(config.dataset_path)):
        stem = file_name.split('.')[0]
        try:
            img_id = int(stem)
        except ValueError:
            # Skip stray entries (e.g. hidden files) instead of aborting the run.
            config.logger.info('skip %s: file name is not a COCO image id', file_name)
            continue

        # getImgIds(imgIds=[img_id]) can only yield img_id itself, so load the
        # image record directly (raises KeyError if the id is not annotated).
        img_info = coco.loadImgs(img_id)[0]
        image_shape = ((img_info['width'], img_info['height']),)
        outputs = _load_head_outputs(result_path, stem)

        # Each image is loaded with batch dimension 1 and image_shape holds one
        # entry, so the batch passed to the engine must be 1. A larger
        # config.batch_size would refer to entries that do not exist.
        # The id is passed in the same form as before (a tuple holding the
        # squeezed file stem) because DetectionEngine's expectations are not
        # visible here.
        detection.detect(outputs, 1, image_shape, (np.squeeze(stem),))

    config.logger.info('Calculating mAP...')
    detection.do_nms_for_results()
    result_file_path = detection.write_result()
    config.logger.info('result file path: %s', result_file_path)
    eval_result = detection.get_eval_result()
    cost_time = time.time() - start_time
    config.logger.info('=============coco 310 infer result=========')
    config.logger.info(eval_result)
    config.logger.info('testing cost time %.2fh', cost_time / 3600.)


if __name__ == "__main__":
    main()

随着人工智能和大数据的发展，各行各业对自动化工具都有一定的需求。在当前疫情防控期间，本项目使用 MindSpore 实现 YOLO 模型进行目标检测及语义分割，可对视频或图片进行口罩佩戴检测和行人社交距离检测，从而对公共场所的疫情防控实行自动化管理。