|
|
@@ -79,12 +79,9 @@ class TextClassificationPipeline(Pipeline): |
|
|
|
'sequence_length': sequence_length, |
|
|
|
**kwargs |
|
|
|
}) |
|
|
|
assert hasattr(self.preprocessor, 'id2label') |
|
|
|
self.id2label = self.preprocessor.id2label |
|
|
|
if self.id2label is None: |
|
|
|
logger.warn( |
|
|
|
'The id2label mapping is None, will return original ids.' |
|
|
|
) |
|
|
|
|
|
|
|
if hasattr(self.preprocessor, 'id2label'): |
|
|
|
self.id2label = self.preprocessor.id2label |
|
|
|
|
|
|
|
def forward(self, inputs: Dict[str, Any], |
|
|
|
**forward_params) -> Dict[str, Any]: |
|
|
@@ -111,6 +108,9 @@ class TextClassificationPipeline(Pipeline): |
|
|
|
if self.model.__class__.__name__ == 'OfaForAllTasks': |
|
|
|
return inputs |
|
|
|
else: |
|
|
|
if getattr(self, 'id2label', None) is None: |
|
|
|
logger.warn( |
|
|
|
'The id2label mapping is None, will return original ids.') |
|
|
|
logits = inputs[OutputKeys.LOGITS].cpu().numpy() |
|
|
|
if logits.shape[0] == 1: |
|
|
|
logits = logits[0] |
|
|
@@ -126,7 +126,7 @@ class TextClassificationPipeline(Pipeline): |
|
|
|
probs = np.take_along_axis(probs, top_indices, axis=-1).tolist() |
|
|
|
|
|
|
|
def map_to_label(id): |
|
|
|
if self.id2label is not None: |
|
|
|
if getattr(self, 'id2label', None) is not None: |
|
|
|
if id in self.id2label: |
|
|
|
return self.id2label[id] |
|
|
|
elif str(id) in self.id2label: |
|
|
|