You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base.py 993 B

12345678910111213141516171819202122232425262728293031
  1. class BaseFilter:
  2. """The base class to define unified interface."""
  3. def hard_judge(self, infer_result=None):
  4. """predict function, and it must be implemented by
  5. different methods class.
  6. :param infer_result: prediction result
  7. :return: `True` means hard sample, `False` means not a hard sample.
  8. """
  9. raise NotImplementedError
  10. class ThresholdFilter(BaseFilter):
  11. def __init__(self, threshold=0.5):
  12. self.threshold = threshold
  13. def hard_judge(self, infer_result=None):
  14. """
  15. :param infer_result: [N, 6], (x0, y0, x1, y1, score, class)
  16. :return: `True` means hard sample, `False` means not a hard sample.
  17. """
  18. if not infer_result:
  19. return True
  20. image_score = 0
  21. for bbox in infer_result:
  22. image_score += bbox[4]
  23. average_score = image_score / (len(infer_result) or 1)
  24. return average_score < self.threshold