
detect.py

import argparse
import time
from pathlib import Path

import cv2
import torch
import torch.backends.cudnn as cudnn
from numpy import random

from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
    scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized, TracedModel
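
# Note: detect() reads the module-level namespace 'opt', which is populated
# by argparse in the __main__ block at the bottom of this file.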

def detect(save_img=False):
    source, weights, view_img, save_txt, imgsz, trace = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size, not opt.no_trace
    save_img = not opt.nosave and not source.endswith('.txt')  # save inference images
    webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
        ('rtsp://', 'rtmp://', 'http://', 'https://'))

    # Directories
    save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run
    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

    # Initialize
    set_logging()
    device = select_device(opt.device)
    half = device.type != 'cpu'  # half precision only supported on CUDA

    # Load model
    model = attempt_load(weights, map_location=device)  # load FP32 model
    stride = int(model.stride.max())  # model stride
    imgsz = check_img_size(imgsz, s=stride)  # check img_size

    if trace:
        model = TracedModel(model, device, opt.img_size)

    if half:
        model.half()  # to FP16

    # Second-stage classifier (disabled by default)
    classify = False
    if classify:
        modelc = load_classifier(name='resnet101', n=2)  # initialize
        # load_state_dict() does not return the module, so move/eval separately
        modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model'])
        modelc.to(device).eval()

    # Set Dataloader
    vid_path, vid_writer = None, None
    if webcam:
        view_img = check_imshow()
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=imgsz, stride=stride)
    else:
        dataset = LoadImages(source, img_size=imgsz, stride=stride)

    # Get names and colors
    names = model.module.names if hasattr(model, 'module') else model.names
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]

    # Run inference
    if device.type != 'cpu':
        model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
    t0 = time.time()
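
    # Each iteration yields: file path, the letterboxed RGB image in CHW order
    # (numpy), the original BGR frame for drawing, and the cv2 capture (if any).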
    for path, img, im0s, vid_cap in dataset:
        img = torch.from_numpy(img).to(device)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        # Inference
        t1 = time_synchronized()
        pred = model(img, augment=opt.augment)[0]

        # Apply NMS
        pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
        t2 = time_synchronized()

        # Apply Classifier
        if classify:
            pred = apply_classifier(pred, modelc, img, im0s)
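
        # After NMS, 'pred' is a list with one tensor per image; each row holds
        # [x1, y1, x2, y2, confidence, class].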
        # Process detections
        for i, det in enumerate(pred):  # detections per image
            if webcam:  # batch_size >= 1
                p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
            else:
                p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)

            p = Path(p)  # to Path
            save_path = str(save_dir / p.name)  # img.jpg
            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # img.txt
            s += '%gx%g ' % img.shape[2:]  # print string
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                # Write results
                for *xyxy, conf, cls in reversed(det):
                    if save_txt:  # Write to file
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh)  # label format
                        with open(txt_path + '.txt', 'a') as f:
                            f.write(('%g ' * len(line)).rstrip() % line + '\n')

                    if save_img or view_img:  # Add bbox to image
                        label = f'{names[int(cls)]} {conf:.2f}'
                        plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)

            # Print time (inference + NMS)
            # print(f'{s}Done. ({t2 - t1:.3f}s)')

            # Stream results
            if view_img:
                cv2.imshow(str(p), im0)
                cv2.waitKey(1)  # 1 millisecond

            # Save results (image with detections)
            if save_img:
                if dataset.mode == 'image':
                    cv2.imwrite(save_path, im0)
                    print(f"The image with the result is saved in: {save_path}")
                else:  # 'video' or 'stream'
                    if vid_path != save_path:  # new video
                        vid_path = save_path
                        if isinstance(vid_writer, cv2.VideoWriter):
                            vid_writer.release()  # release previous video writer
                        if vid_cap:  # video
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        else:  # stream
                            fps, w, h = 30, im0.shape[1], im0.shape[0]
                            save_path += '.mp4'
                        vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                    vid_writer.write(im0)

    if save_txt or save_img:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        # print(f"Results saved to {save_dir}{s}")

    print(f'Done. ({time.time() - t0:.3f}s)')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default='yolov7.pt', help='model.pt path(s)')
    parser.add_argument('--source', type=str, default='inference/images', help='source')  # file/folder, 0 for webcam
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true', help='display results')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--update', action='store_true', help='update all models')
    parser.add_argument('--project', default='runs/detect', help='save results to project/name')
    parser.add_argument('--name', default='exp', help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--no-trace', action='store_true', help="don't trace model")
    opt = parser.parse_args()
    print(opt)
    # check_requirements(exclude=('pycocotools', 'thop'))

    with torch.no_grad():
        if opt.update:  # update all models (to fix SourceChangeWarning)
            for opt.weights in ['yolov7.pt']:
                detect()
                strip_optimizer(opt.weights)
        else:
            detect()
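
Typical invocations, assuming a yolov7.pt checkpoint in the working directory (each flag maps to an argparse option defined above):

    python detect.py --weights yolov7.pt --source inference/images --conf-thres 0.25
    python detect.py --weights yolov7.pt --source 0 --view-img    # webcam stream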

With the development of artificial intelligence and big data, there is growing demand for automated tools in many fields. During the current period of epidemic prevention and control, this project uses MindSpore to implement a YOLO model for object detection and semantic segmentation. It can run mask-wearing detection and pedestrian social-distancing detection on both videos and images, enabling automated epidemic-control management of public places.
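
The social-distancing check is not part of detect.py itself; a minimal post-processing sketch over the detector's pixel-space xyxy boxes could look like the following (the helper name close_pairs and the 100-pixel threshold are illustrative assumptions, and a real deployment would calibrate the threshold to the camera geometry):

    import itertools
    import math

    def close_pairs(person_boxes, min_dist_px=100.0):
        """person_boxes: iterable of (x1, y1, x2, y2) in pixels.
        Returns index pairs whose box centroids are closer than min_dist_px."""
        centroids = [((x1 + x2) / 2.0, (y1 + y2) / 2.0) for x1, y1, x2, y2 in person_boxes]
        flagged = []
        for (i, (xa, ya)), (j, (xb, yb)) in itertools.combinations(enumerate(centroids), 2):
            if math.hypot(xa - xb, ya - yb) < min_dist_px:  # Euclidean centroid distance
                flagged.append((i, j))
        return flagged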