
detect.py

import argparse
import os
import platform
import shutil
import time
from pathlib import Path

import cv2
import torch
import torch.backends.cudnn as cudnn
from numpy import random

from utils.utils import *  # presumably provides distancing() and plot_dots_on_people() used below
from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import (
    check_img_size, non_max_suppression, apply_classifier, scale_coords,
    xyxy2xywh, plot_one_box, strip_optimizer, set_logging)
from utils.torch_utils import select_device, load_classifier, time_synchronized
def detect(save_img=False):
    out, source, weights, view_img, save_txt, imgsz = \
        opt.output, opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
    webcam = source.isnumeric() or source.startswith('rtsp') or source.startswith('http') or source.endswith('.txt')
    print('-----')
    print(source)
    print(type(source))

    # Initialize
    set_logging()
    device = select_device(opt.device)
    if os.path.exists(out):
        shutil.rmtree(out)  # delete output folder
    os.makedirs(out)  # make new output folder
    half = device.type != 'cpu'  # half precision only supported on CUDA

    # Load model
    model = attempt_load(weights, map_location=device)  # load FP32 model
    imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size
    if half:
        model.half()  # to FP16

    # Second-stage classifier
    classify = False
    if classify:
        modelc = load_classifier(name='resnet101', n=2)  # initialize
        modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model'])  # load weights
        modelc.to(device).eval()

    # Set Dataloader
    vid_path, vid_writer = None, None
    if webcam:
        view_img = True
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=imgsz)
    else:
        save_img = True
        dataset = LoadImages(source, img_size=imgsz)

    # Get names and colors
    names = model.module.names if hasattr(model, 'module') else model.names
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]

    # Run inference
    t0 = time.time()
    img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
    _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
    for path, img, im0s, vid_cap in dataset:
        print('path: {0}'.format(path))
        print('im0s: {0}'.format(im0s))
        print('type of im0s: {0}'.format(type(im0s)))
        img = torch.from_numpy(img).to(device)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        # Inference
        t1 = time_synchronized()
        pred = model(img, augment=opt.augment)[0]

        # Apply NMS
        pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
        t2 = time_synchronized()

        # Apply Classifier
        if classify:
            pred = apply_classifier(pred, modelc, img, im0s)

        # List holding the bounding-box coordinates of detected people --linjie
        people_coords = []

        # Process detections
        for i, det in enumerate(pred):  # detections per image
            if webcam:  # batch_size >= 1
                p, s, im0 = path[i], '%g: ' % i, im0s[i].copy()
            else:
                p, s, im0 = path, '', im0s

            save_path = str(Path(out) / Path(p).name)
            txt_path = str(Path(out) / Path(p).stem) + ('_%g' % dataset.frame if dataset.mode == 'video' else '')
            s += '%gx%g ' % img.shape[2:]  # print string
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            if det is not None and len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += '%g %ss, ' % (n, names[int(c)])  # add to string

                # Write results
                for *xyxy, conf, cls in reversed(det):
                    if save_txt:  # Write to file
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        with open(txt_path + '.txt', 'a') as f:
                            f.write(('%g ' * 5 + '\n') % (cls, *xywh))  # label format

                    if save_img or view_img:  # Add bbox to image
                        label = '%s %.2f' % (names[int(cls)], conf)
                        plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)

                        # Check whether the detected label is a person --linjie
                        if label is not None:
                            if (label.split())[0] == 'person':
                                print('label is person')
                                distancing(people_coords, im0, dist_thres_lim=(200, 250))
                                people_coords.append(xyxy)
                                # plot_one_box(xyxy, im0, line_thickness=3)
                                plot_dots_on_people(xyxy, im0)

                # Draw the lines connecting people --linjie
                distancing(people_coords, im0, dist_thres_lim=(200, 250))

            # Print time (inference + NMS)
            print('%sDone. (%.3fs)' % (s, t2 - t1))

            # Stream results
            if view_img:
                cv2.imshow(p, im0)
                if cv2.waitKey(1) == ord('q'):  # q to quit
                    raise StopIteration

            # Save results (image with detections)
            if save_img:
                if dataset.mode == 'images':
                    cv2.imwrite(save_path, im0)
                else:
                    if vid_path != save_path:  # new video
                        vid_path = save_path
                        if isinstance(vid_writer, cv2.VideoWriter):
                            vid_writer.release()  # release previous video writer

                        fourcc = 'mp4v'  # output video codec
                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
                    vid_writer.write(im0)

    if save_txt or save_img:
        print('Results saved to %s' % Path(out))
        if platform.system() == 'Darwin' and not opt.update:  # macOS
            os.system('open ' + save_path)

    print('Done. (%.3fs)' % (time.time() - t0))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default='weights/best.pt', help='model.pt path(s)')
    parser.add_argument('--source', type=str, default='inference/images', help='source')  # file/folder inference/images, 0 for webcam
    parser.add_argument('--output', type=str, default='inference/output', help='output folder')  # output folder
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.4, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true', help='display results')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--update', action='store_true', help='update all models')
    opt = parser.parse_args()
    print(opt)

    with torch.no_grad():
        if opt.update:  # update all models (to fix SourceChangeWarning)
            for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
                print('model1')
                detect()
                strip_optimizer(opt.weights)
        else:
            print('model2')
            detect()
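
The script calls distancing() and plot_dots_on_people(), which come in through the wildcard import from utils.utils and are not defined in this file. For orientation only, here is a minimal sketch of what such helpers might look like; the function names and call signatures are taken from the call sites above, but the internal logic (box centres, pairwise pixel distance, colour-coded lines) is an assumption about the project's actual implementation:

import cv2


def plot_dots_on_people(xyxy, img):
    # Assumed behaviour: mark a detected person with a dot at the
    # centre of the bounding box (xyxy = x1, y1, x2, y2).
    x1, y1, x2, y2 = [int(t) for t in xyxy]
    cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
    cv2.circle(img, (cx, cy), 5, (0, 255, 0), -1)


def distancing(people_coords, img, dist_thres_lim=(200, 250)):
    # Assumed behaviour: draw a line between every pair of people whose
    # centre-to-centre pixel distance is below the thresholds; red when
    # closer than the lower bound, yellow inside the (low, high) band.
    centers = []
    for xyxy in people_coords:
        x1, y1, x2, y2 = [int(t) for t in xyxy]
        centers.append(((x1 + x2) // 2, (y1 + y2) // 2))
    for i in range(len(centers)):
        for j in range(i + 1, len(centers)):
            (xa, ya), (xb, yb) = centers[i], centers[j]
            dist = ((xa - xb) ** 2 + (ya - yb) ** 2) ** 0.5
            if dist < dist_thres_lim[0]:
                cv2.line(img, (xa, ya), (xb, yb), (0, 0, 255), 2)   # too close
            elif dist < dist_thres_lim[1]:
                cv2.line(img, (xa, ya), (xb, yb), (0, 255, 255), 2)  # borderline

Note that the (200, 250) thresholds are raw pixel distances, so how they map to real-world distance depends on camera placement and resolution.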

With the development of artificial intelligence and big data, every field has a growing demand for automation tools. During the current period of epidemic prevention and control, this project uses MindSpore to implement a YOLO model for object detection and semantic segmentation. It can perform mask-wearing detection and pedestrian social-distance detection on both videos and images, enabling automated epidemic-prevention management of public places.
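
For reference, given the argument parser in detect.py, a typical run would look like the following; the paths shown are the parser defaults and should be adjusted to your own weights and data:

# Run detection on a folder of images with the project weights
python detect.py --source inference/images --weights weights/best.pt --output inference/output

# Run on a webcam stream and display results live (press q to quit)
python detect.py --source 0 --weights weights/best.pt --view-img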