|
|
@@ -45,6 +45,9 @@ class FacialExpressionRecognitionPipeline(Pipeline): |
|
|
|
|
|
|
|
# face detect pipeline |
|
|
|
det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' |
|
|
|
self.map_list = [ |
|
|
|
'Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral' |
|
|
|
] |
|
|
|
self.face_detection = pipeline( |
|
|
|
Tasks.face_detection, model=det_model_id) |
|
|
|
|
|
|
@@ -122,7 +125,7 @@ class FacialExpressionRecognitionPipeline(Pipeline): |
|
|
|
labels = result[1].tolist() |
|
|
|
return { |
|
|
|
OutputKeys.SCORES: scores, |
|
|
|
OutputKeys.LABELS: labels, |
|
|
|
OutputKeys.LABELS: self.map_list[labels] |
|
|
|
} |
|
|
|
|
|
|
|
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: |
|
|
|