|
|
@@ -1,10 +1,13 @@ |
|
|
|
import os.path as osp |
|
|
|
from typing import Any, Dict |
|
|
|
|
|
|
|
import cv2 |
|
|
|
import numpy as np |
|
|
|
import PIL |
|
|
|
import torch |
|
|
|
|
|
|
|
from modelscope.metainfo import Pipelines |
|
|
|
from modelscope.models.cv.face_detection.retinaface import detection |
|
|
|
from modelscope.models.cv.face_detection import RetinaFaceDetection |
|
|
|
from modelscope.outputs import OutputKeys |
|
|
|
from modelscope.pipelines.base import Input, Pipeline |
|
|
|
from modelscope.pipelines.builder import PIPELINES |
|
|
@@ -28,7 +31,7 @@ class RetinaFaceDetectionPipeline(Pipeline): |
|
|
|
super().__init__(model=model, **kwargs) |
|
|
|
ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE) |
|
|
|
logger.info(f'loading model from {ckpt_path}') |
|
|
|
detector = detection.RetinaFaceDetection( |
|
|
|
detector = RetinaFaceDetection( |
|
|
|
model_path=ckpt_path, device=self.device) |
|
|
|
self.detector = detector |
|
|
|
logger.info('load model done') |
|
|
|