
track.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Author : linjie
import sys

sys.path.insert(0, './yolov5')

import argparse
import os
import platform
import shutil
import time
from pathlib import Path

import cv2
import torch
import torch.backends.cudnn as cudnn

from utils.datasets import LoadImages, LoadStreams
from utils.general import (
    check_img_size, non_max_suppression, apply_classifier, scale_coords,
    xyxy2xywh, plot_one_box, strip_optimizer, set_logging)
from utils.torch_utils import select_device, time_synchronized
from deep_sort_pytorch.utils.parser import get_config
from deep_sort_pytorch.deep_sort import DeepSort

palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)
def bbox_rel(*xyxy):
    """Convert an absolute-pixel xyxy box to (center_x, center_y, width, height)."""
    bbox_left = min([xyxy[0].item(), xyxy[2].item()])
    bbox_top = min([xyxy[1].item(), xyxy[3].item()])
    bbox_w = abs(xyxy[0].item() - xyxy[2].item())
    bbox_h = abs(xyxy[1].item() - xyxy[3].item())
    x_c = (bbox_left + bbox_w / 2)
    y_c = (bbox_top + bbox_h / 2)
    w = bbox_w
    h = bbox_h
    return x_c, y_c, w, h
def compute_color_for_labels(label):
    """Simple function that assigns a fixed color depending on the class."""
    color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette]
    return tuple(color)
past_identities = []  # kept from the original source (currently unused)


def draw_boxes(img, bbox, cls_names, scores, identities=None, offset=(0, 0)):
    for i, box in enumerate(bbox):
        x1, y1, x2, y2 = [int(c) for c in box]  # use 'c', not 'i', to avoid shadowing the loop index
        x1 += offset[0]
        x2 += offset[0]
        y1 += offset[1]
        y2 += offset[1]
        # box text and bar
        id = int(identities[i]) if identities is not None else 0
        if identities is not None and 1 not in identities:
            print('Track ID 1 is absent from the current frame:')
            print(identities)
        color = compute_color_for_labels(id)
        label = '%d %s %d' % (id, cls_names[i], scores[i])
        label += '%'
        print('Person %d appeared!' % id)
        t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 2, 2)[0]
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 3)
        cv2.rectangle(
            img, (x1, y1), (x1 + t_size[0] + 3, y1 + t_size[1] + 4), color, -1)
        cv2.putText(img, label, (x1, y1 + t_size[1] + 4),
                    cv2.FONT_HERSHEY_PLAIN, 2, [255, 255, 255], 2)
    return img
def detect(opt, save_img=False):
    out, source, weights, view_img, save_txt, imgsz = \
        opt.output, opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
    webcam = source == '0' or source.startswith(
        'rtsp') or source.startswith('http') or source.endswith('.txt')

    # initialize deepsort
    cfg = get_config()
    cfg.merge_from_file(opt.config_deepsort)
    deepsort = DeepSort(cfg.DEEPSORT.REID_CKPT,
                        max_dist=cfg.DEEPSORT.MAX_DIST, min_confidence=cfg.DEEPSORT.MIN_CONFIDENCE,
                        nms_max_overlap=cfg.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg.DEEPSORT.MAX_IOU_DISTANCE,
                        max_age=cfg.DEEPSORT.MAX_AGE, n_init=cfg.DEEPSORT.N_INIT, nn_budget=cfg.DEEPSORT.NN_BUDGET,
                        use_cuda=True)

    # Initialize
    device = select_device(opt.device)
    if os.path.exists(out):
        shutil.rmtree(out)  # delete output folder
    os.makedirs(out)  # make new output folder
    half = device.type != 'cpu'  # half precision only supported on CUDA

    # Load model
    model = torch.load(weights, map_location=device)['model'].float()  # load to FP32
    model.to(device).eval()
    if half:
        model.half()  # to FP16

    # Set Dataloader
    vid_path, vid_writer = None, None
    if webcam:
        view_img = True
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=imgsz)
    else:
        view_img = True
        save_img = True
        dataset = LoadImages(source, img_size=imgsz)

    # Get names and colors
    names = model.module.names if hasattr(model, 'module') else model.names

    # Run inference
    t0 = time.time()
    img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
    _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once

    save_path = str(Path(out))
    txt_path = str(Path(out)) + '/results.txt'
    for frame_idx, (path, img, im0s, vid_cap) in enumerate(dataset):
        img = torch.from_numpy(img).to(device)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        # Inference
        t1 = time_synchronized()
        pred = model(img, augment=opt.augment)[0]

        # Apply NMS
        pred = non_max_suppression(
            pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
        t2 = time_synchronized()

        # Process detections
        for i, det in enumerate(pred):  # detections per image
            if webcam:  # batch_size >= 1
                p, s, im0 = path[i], '%g: ' % i, im0s[i].copy()
            else:
                p, s, im0 = path, '', im0s

            s += '%gx%g ' % img.shape[2:]  # print string
            save_path = str(Path(out) / Path(p).name)

            if det is not None and len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(
                    img.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += '%g %ss, ' % (n, names[int(c)])  # add to string

                bbox_xywh = []
                confs = []
                clses = []

                # Adapt detections to deep sort input format
                for *xyxy, conf, cls in det:
                    x_c, y_c, bbox_w, bbox_h = bbox_rel(*xyxy)
                    obj = [x_c, y_c, bbox_w, bbox_h]
                    bbox_xywh.append(obj)
                    confs.append([conf.item()])
                    clses.append([cls.item()])

                xywhs = torch.Tensor(bbox_xywh)
                confss = torch.Tensor(confs)
                clses = torch.Tensor(clses)

                outputs = deepsort.update(xywhs, confss, clses, im0)

                # draw boxes for visualization
                if len(outputs) > 0:
                    bbox_tlwh = []  # kept from the original source (unused)
                    bbox_xyxy = outputs[:, :4]
                    identities = outputs[:, 4]
                    clses = outputs[:, 5]
                    scores = outputs[:, 6]
                    stays = outputs[:, 7]  # extra per-track column from this fork's DeepSort (unused below)
                    draw_boxes(im0, bbox_xyxy, [names[int(c)] for c in clses],
                               scores, identities)

                # Write MOT compliant results to file
                if save_txt and len(outputs) != 0:
                    for j, output in enumerate(outputs):
                        bbox_left = output[0]
                        bbox_top = output[1]
                        bbox_w = output[2]
                        bbox_h = output[3]
                        identity = output[4]  # track id is column 4; output[-1] would read the extra columns
                        with open(txt_path, 'a') as f:
                            f.write(('%g ' * 10 + '\n') % (frame_idx, identity, bbox_left,
                                                           bbox_top, bbox_w, bbox_h, -1, -1, -1, -1))  # label format
            else:
                deepsort.increment_ages()
            # Print time (inference + NMS)
            print('%sDone. (%.3fs)' % (s, t2 - t1))

            # Stream results
            if view_img:
                cv2.imshow(p, im0)
                if cv2.waitKey(1) == ord('q'):  # q to quit
                    raise StopIteration

            # Save results (image with detections)
            if save_img:
                print('saving img!')
                if dataset.mode == 'images':
                    cv2.imwrite(save_path, im0)
                else:
                    print('saving video!')
                    if vid_path != save_path:  # new video
                        vid_path = save_path
                        if isinstance(vid_writer, cv2.VideoWriter):
                            vid_writer.release()  # release previous video writer
                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        vid_writer = cv2.VideoWriter(
                            save_path, cv2.VideoWriter_fourcc(*opt.fourcc), fps, (w, h))
                    vid_writer.write(im0)

    if save_txt or save_img:
        print('Results saved to %s' % (os.getcwd() + os.sep + out))
        if platform.system() == 'Darwin':  # MacOS; platform is a module, so compare platform.system()
            os.system('open ' + save_path)

    print('Done. (%.3fs)' % (time.time() - t0))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str,
                        default='weights/yolov5s.pt', help='model.pt path')
    # file/folder, 0 for webcam
    parser.add_argument('--source', type=str,
                        default='inference/images', help='source')
    parser.add_argument('--output', type=str, default='inference/output',
                        help='output folder')
    parser.add_argument('--img-size', type=int, default=640,
                        help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float,
                        default=0.4, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float,
                        default=0.5, help='IOU threshold for NMS')
    parser.add_argument('--fourcc', type=str, default='mp4v',
                        help='output video codec (verify ffmpeg support)')
    parser.add_argument('--device', default='',
                        help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true',
                        help='display results')
    parser.add_argument('--save-txt', action='store_true',
                        help='save results to *.txt')
    # class 0 is person
    parser.add_argument('--classes', nargs='+', type=int,
                        default=[0], help='filter by class')
    parser.add_argument('--agnostic-nms', action='store_true',
                        help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true',
                        help='augmented inference')
    parser.add_argument('--config_deepsort', type=str,
                        default='deep_sort_pytorch/configs/deep_sort.yaml')
    args = parser.parse_args()
    args.img_size = check_img_size(args.img_size)
    print(args)

    with torch.no_grad():
        detect(args)
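When run with --save-txt, each row appended to results.txt holds ten space-separated values: frame index, track identity, the box's top-left corner and size, and four -1 placeholders in MOT challenge style. A minimal sketch for reading those rows back, assuming only the write format shown above (the helper name is illustrative, not part of this repo):

def load_mot_results(txt_path):
    """Parse rows written by detect() with --save-txt.

    Each row: frame_idx identity bbox_left bbox_top bbox_w bbox_h -1 -1 -1 -1
    Returns a dict mapping frame index -> list of (identity, left, top, w, h).
    """
    tracks = {}
    with open(txt_path) as f:
        for line in f:
            vals = line.split()
            if len(vals) < 6:
                continue  # skip malformed rows
            frame, ident = int(float(vals[0])), int(float(vals[1]))
            left, top, w, h = map(float, vals[2:6])
            tracks.setdefault(frame, []).append((ident, left, top, w, h))
    return tracks

For example, load_mot_results('inference/output/results.txt').get(0, []) returns every track recorded in the first frame.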

With the development of artificial intelligence and big data, every field has some demand for automation tools. During the current epidemic-prevention period, this project uses MindSpore to implement a YOLO model for object detection and semantic segmentation. For both video and images it can perform mask-wearing detection and pedestrian social-distance detection, enabling automated epidemic-control management of public places.
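As a rough illustration of the social-distance check described above (not part of track.py; the pixel threshold and function name are assumptions for this sketch), the centers of the tracked pedestrian boxes returned by deepsort.update() can be compared pairwise against a distance threshold:

import itertools

def flag_close_pairs(outputs, min_pixels=100):
    """Return identity pairs whose box centers are closer than min_pixels.

    outputs: rows of [x1, y1, x2, y2, identity, ...] as produced by
    deepsort.update() in track.py. min_pixels is an assumed image-space
    threshold; a deployed system would calibrate it to real-world distance.
    """
    centers = [((o[0] + o[2]) / 2.0, (o[1] + o[3]) / 2.0, int(o[4])) for o in outputs]
    pairs = []
    for (xa, ya, ia), (xb, yb, ib) in itertools.combinations(centers, 2):
        if ((xa - xb) ** 2 + (ya - yb) ** 2) ** 0.5 < min_pixels:
            pairs.append((ia, ib))  # this pair is closer than the threshold
    return pairs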