|
- import sys
- import os
- import random
- import numpy as np
- sys.path.append("/home/shanwei-luo/userdata/mmdetection")
- from mmdet.apis import (async_inference_detector, inference_detector,
- init_detector, show_result_pyplot)
- import argparse
- #python select_threshold.py --config_file /home/shanwei-luo/teamdata/anomaly_detection_active_learning/model/work_dirs/AD_dsxw_test66_06_10/AD_dsxw_test66_06_10.py --checkpoint_file /home/shanwei-luo/teamdata/anomaly_detection_active_learning/model/work_dirs/AD_dsxw_test66_06_10/latest.pth --images_path /home/shanwei-luo/teamdata/anomaly_detection_active_learning/data0422/smd12_2106_10/test/ --test_batch_size 128
def parse_args(argv=None):
    """Parse the command-line options of the threshold search.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is how the script uses it.

    Returns:
        argparse.Namespace with ``config_file``, ``checkpoint_file``,
        ``images_path`` and ``test_batch_size`` (an ``int``).

    Raises:
        SystemExit: (from argparse) when a required option is missing or
            ``--test_batch_size`` is not an integer.
    """
    parser = argparse.ArgumentParser(description='get best threshold')
    # The first three are needed to build the model and find the images, so
    # fail at parse time instead of with a TypeError on None later on.
    parser.add_argument('--config_file', required=True,
                        help='detector config file passed to init_detector')
    parser.add_argument('--checkpoint_file', required=True,
                        help='checkpoint file passed to init_detector')
    parser.add_argument('--images_path', required=True,
                        help="test-set root containing 'ok/' and 'ng/' sub-directories")
    parser.add_argument('--test_batch_size', type=int, default=1,
                        help='number of images passed to each inference call')
    return parser.parse_args(argv)
-
args = parse_args()

# AOI category names; an image whose third '@'-separated filename field is one
# of these keys is left out of the threshold search. The values are not used
# in this script.
class_AOI_name = {
    "bu_pi_pei": "1",
    "fang_xiang_fan": "2",
    "err.txt_c_not_f": "3",
    "shang_xi_bu_lia": "4",
}

config_file_1 = args.config_file
checkpoint_file_1 = args.checkpoint_file

# Images live in <images_path>/ok/ (label 0) and <images_path>/ng/ (label 1).
# Joining with a trailing "" keeps the trailing separator that plain string
# concatenation used to produce (so "<dir>" + filename is still a valid path)
# and also works when --images_path has no trailing slash.
imgs_ok_path = os.path.join(args.images_path, "ok", "")
imgs_ng_path = os.path.join(args.images_path, "ng", "")

model_1 = init_detector(config_file_1, checkpoint_file_1, device='cuda:0')

# os.listdir order is arbitrary; sort so repeated runs process images in the
# same order.
imgs_ok = sorted(os.listdir(imgs_ok_path))
imgs_ng = sorted(os.listdir(imgs_ng_path))

count_label_ok = len(imgs_ok)
count_label_ng = len(imgs_ng)
print(count_label_ok, count_label_ng)
def _is_excluded(filename):
    """Return True if *filename*'s category field is a key of class_AOI_name.

    The category is the third '@'-separated field of the filename. A name
    with fewer than three fields has no category field, so it cannot be in an
    excluded category and is kept (indexing it directly raised IndexError).
    """
    fields = filename.split("@")
    return len(fields) > 2 and fields[2] in class_AOI_name


def _collect_images(filenames, directory, label):
    """Return (paths, labels) for the non-excluded *filenames* in *directory*."""
    paths = [os.path.join(directory, name)
             for name in filenames if not _is_excluded(name)]
    return paths, [label] * len(paths)


# Label 0 = image from the ok/ directory, 1 = image from the ng/ directory.
ok_paths, ok_labels = _collect_images(imgs_ok, imgs_ok_path, 0)
ng_paths, ng_labels = _collect_images(imgs_ng, imgs_ng_path, 1)

# Number of images of each label that survive the exclusion; these are the
# recall denominators used in the threshold search.
count_label_ok = len(ok_paths)
count_label_ng = len(ng_paths)

# OK images first, then NG images; imgs_name[i] has label imgs_labels[i].
imgs_name = ok_paths + ng_paths
imgs_labels = ok_labels + ng_labels

print(count_label_ok, count_label_ng, len(imgs_labels))
print("before infer")
num = len(imgs_name)
step = int(args.test_batch_size)
# A non-positive batch size would never advance through the image list.
if step <= 0:
    raise ValueError(f"--test_batch_size must be a positive integer, got {step}")

results_1 = []
for start in range(0, num, step):
    # inference_detector is expected to return one result per image for list
    # input; results_1 must stay index-aligned with imgs_name / imgs_labels.
    # Slicing past the end is safe, so the last, shorter batch needs no special case.
    results_1.extend(inference_detector(model_1, imgs_name[start:start + step]))
    print(len(results_1))
print("after infer")
-
-
def _image_max_score(result):
    """Return the highest detection score in one image's result, or -inf if none.

    Reads column 4 of each per-class detection array as the score, as the
    threshold loop always has. If *result* is a tuple, only its first element
    is used -- NOTE(review): this assumes mask models return
    (bbox_results, segm_results); confirm for the mmdet version in use.
    """
    if isinstance(result, tuple):
        result = result[0]
    best = float("-inf")
    for class_dets in result:
        if len(class_dets):
            best = max(best, float(np.max(class_dets[:, 4])))
    return best


def _select_best_threshold(results, labels, thresholds=None,
                           ok_weight=0.3, ng_weight=0.7):
    """Pick the score threshold that maximises a weighted OK/NG recall.

    An image is predicted NG when any of its detections scores strictly above
    the threshold. For each candidate threshold, this function prints the OK
    recall (OK images with no such detection) and the NG recall (NG images
    with one). It scores the threshold as
    ``ok_weight * recall_ok + ng_weight * recall_ng``. When several thresholds
    tie, the first one with the highest score wins.

    Args:
        results: per-image detector outputs, index-aligned with *labels*.
        labels: per-image ground truth, 0 = OK, 1 = NG.
        thresholds: candidate thresholds; defaults to
            ``np.arange(0.01, 0.3, 0.005)``.
        ok_weight: weight of the OK recall.
        ng_weight: weight of the NG recall.

    Returns:
        ``(best_threshold, best_weighted_recall)``. If no threshold scores
        above 0, returns ``(0.005, 0)``.
    """
    if thresholds is None:
        thresholds = np.arange(0.01, 0.3, 0.005)

    # "Some box scores > thr" is the same as "the max box score > thr", so
    # reduce every image to one number once instead of rescanning all boxes
    # for every threshold.
    max_scores = np.array([_image_max_score(r) for r in results], dtype=float)
    labels = np.asarray(labels)
    is_ok = labels == 0
    is_ng = labels == 1
    n_ok = int(is_ok.sum())
    n_ng = int(is_ng.sum())

    best_score = 0
    best_thr = 0.005
    for score_thr in thresholds:
        predicted_ng = max_scores > score_thr
        count_ok = int(np.sum(is_ok & ~predicted_ng))
        count_ng = int(np.sum(is_ng & predicted_ng))
        # With no images of a label the recall is undefined; count it as 0
        # so the search still ranks thresholds by the other label.
        recall_ok = count_ok / n_ok if n_ok else 0.0
        recall_ng = count_ng / n_ng if n_ng else 0.0
        print("score_thr:", score_thr, " recall(ok):", recall_ok, " recall(ng):", recall_ng)

        weighted = ok_weight * recall_ok + ng_weight * recall_ng
        if best_score < weighted:
            best_score = weighted
            best_thr = score_thr
    return best_thr, best_score


best_thr, recall_ok_ng = _select_best_threshold(results_1, imgs_labels)
print("***********************")
print("best threshold:", best_thr)
|