|
|
@@ -3,14 +3,14 @@ from typing import Any, Dict, Optional |
|
|
|
from maas_lib.models.nlp import DialogIntentModel |
|
|
|
from maas_lib.preprocessors import DialogIntentPreprocessor |
|
|
|
from maas_lib.utils.constant import Tasks |
|
|
|
from ...base import Model, Tensor |
|
|
|
from ...base import Input, Pipeline |
|
|
|
from ...builder import PIPELINES |
|
|
|
|
|
|
|
__all__ = ['DialogIntentPipeline'] |
|
|
|
|
|
|
|
|
|
|
|
@PIPELINES.register_module(Tasks.dialog_intent, module_name=r'space-intent') |
|
|
|
class DialogIntentPipeline(Model): |
|
|
|
class DialogIntentPipeline(Pipeline): |
|
|
|
|
|
|
|
def __init__(self, model: DialogIntentModel, |
|
|
|
preprocessor: DialogIntentPreprocessor, **kwargs): |
|
|
@@ -23,9 +23,9 @@ class DialogIntentPipeline(Model): |
|
|
|
|
|
|
|
super().__init__(model=model, preprocessor=preprocessor, **kwargs) |
|
|
|
self.model = model |
|
|
|
self.tokenizer = preprocessor.tokenizer |
|
|
|
# self.tokenizer = preprocessor.tokenizer |
|
|
|
|
|
|
|
def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: |
|
|
|
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: |
|
|
|
"""process the prediction results |
|
|
|
|
|
|
|
Args: |
|
|
@@ -35,16 +35,4 @@ class DialogIntentPipeline(Model): |
|
|
|
Dict[str, str]: the prediction results |
|
|
|
""" |
|
|
|
|
|
|
|
vocab_size = len(self.tokenizer.vocab) |
|
|
|
pred_list = inputs['predictions'] |
|
|
|
pred_ids = pred_list[0][0].cpu().numpy().tolist() |
|
|
|
for j in range(len(pred_ids)): |
|
|
|
if pred_ids[j] >= vocab_size: |
|
|
|
pred_ids[j] = 100 |
|
|
|
pred = self.tokenizer.convert_ids_to_tokens(pred_ids) |
|
|
|
pred_string = ''.join(pred).replace( |
|
|
|
'##', |
|
|
|
'').split('[SEP]')[0].replace('[CLS]', |
|
|
|
'').replace('[SEP]', |
|
|
|
'').replace('[UNK]', '') |
|
|
|
return {'pred_string': pred_string} |
|
|
|
return inputs |