@@ -65,5 +65,7 @@ class DialogIntentModel(Model):
         """
        from numpy import array, float32
        import torch
-        return {}
+        # Run the preprocessed batch through the trainer's forward pass.
+        result = self.trainer.forward(input)
+        return result
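With this hunk, `DialogIntentModel.forward` stops returning an empty dict and delegates to its trainer (presumably the `IntentTrainer` extended later in this diff). A minimal sketch of the resulting call chain, mirroring the test at the bottom of this diff; the checkpoint path, sample utterance, and the preprocessor's constructor signature are assumptions, not taken from the repo:

```python
from maas_lib.models.nlp import DialogIntentModel
from maas_lib.preprocessors import DialogIntentPreprocessor

model_dir = './space-intent'  # hypothetical local checkpoint directory
preprocessor = DialogIntentPreprocessor(model_dir=model_dir)  # signature assumed
model = DialogIntentModel(
    model_dir=model_dir,
    text_field=preprocessor.text_field,
    config=preprocessor.config)

# The preprocessor now returns a collated batch (see the preprocessor hunk
# below), which forward() hands straight to the trainer.
batch = preprocessor('I would like to book a table for two')
print(model.forward(batch))  # e.g. {'pred': [...]}
```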
@@ -3,14 +3,14 @@ from typing import Any, Dict, Optional
 from maas_lib.models.nlp import DialogIntentModel
 from maas_lib.preprocessors import DialogIntentPreprocessor
 from maas_lib.utils.constant import Tasks
-from ...base import Model, Tensor
+from ...base import Input, Pipeline
 from ...builder import PIPELINES

 __all__ = ['DialogIntentPipeline']


 @PIPELINES.register_module(Tasks.dialog_intent, module_name=r'space-intent')
-class DialogIntentPipeline(Model):
+class DialogIntentPipeline(Pipeline):

     def __init__(self, model: DialogIntentModel,
                  preprocessor: DialogIntentPreprocessor, **kwargs):
@@ -23,9 +23,9 @@ class DialogIntentPipeline(Model):

         super().__init__(model=model, preprocessor=preprocessor, **kwargs)
         self.model = model
-        self.tokenizer = preprocessor.tokenizer
+        # self.tokenizer = preprocessor.tokenizer

-    def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
         """process the prediction results

         Args:
@@ -35,16 +35,4 @@ class DialogIntentPipeline(Model):
             Dict[str, str]: the prediction results
         """
-        vocab_size = len(self.tokenizer.vocab)
-        pred_list = inputs['predictions']
-        pred_ids = pred_list[0][0].cpu().numpy().tolist()
-        for j in range(len(pred_ids)):
-            if pred_ids[j] >= vocab_size:
-                pred_ids[j] = 100
-        pred = self.tokenizer.convert_ids_to_tokens(pred_ids)
-        pred_string = ''.join(pred).replace(
-            '##',
-            '').split('[SEP]')[0].replace('[CLS]',
-                                          '').replace('[SEP]',
-                                                      '').replace('[UNK]', '')
-        return {'pred_string': pred_string}
+        return inputs
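For reference, the block removed above was WordPiece detokenization left over from the generation pipeline; intent classification yields class probabilities rather than generated token ids, so `postprocess` now passes the model output through unchanged. A standalone sketch of the dropped decoding logic, assuming a BERT-style tokenizer in which id 100 is `[UNK]`:

```python
def decode_prediction(pred_ids, tokenizer):
    """Turn generated token ids back into a plain string (sketch)."""
    vocab_size = len(tokenizer.vocab)
    # Clamp out-of-vocabulary ids to [UNK] (id 100 in BERT vocabularies).
    pred_ids = [i if i < vocab_size else 100 for i in pred_ids]
    tokens = tokenizer.convert_ids_to_tokens(pred_ids)
    text = ''.join(tokens).replace('##', '')  # undo WordPiece splitting
    text = text.split('[SEP]')[0]             # keep text before the first [SEP]
    for special in ('[CLS]', '[UNK]'):
        text = text.replace(special, '')
    return text
```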
@@ -43,5 +43,7 @@ class DialogIntentPreprocessor(Preprocessor):
         Returns:
             Dict[str, Any]: the preprocessed data
         """
-        return self.text_field.preprocessor(data)
+        samples = self.text_field.preprocessor([data])
+        samples, _ = self.text_field.collate_fn_multi_turn(samples)
+        return samples
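The preprocessor previously returned raw features for a single utterance; it now wraps the input in a one-element list and runs it through `collate_fn_multi_turn`, discarding the second return value, so callers receive a padded, batch-shaped sample. A usage sketch; the path, constructor signature, and exact contents of the returned dict are assumptions:

```python
from maas_lib.preprocessors import DialogIntentPreprocessor

# Path and constructor signature assumed from the test file in this diff.
preprocessor = DialogIntentPreprocessor(model_dir='./space-intent')

sample = preprocessor('I want to cancel my reservation')
# `sample` should be a collated, batch-of-one dict of padded id/mask
# arrays, matching what the trainer's forward() expects as `batch`.
print(sample.keys())
```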
@@ -506,6 +506,28 @@ class IntentTrainer(Trainer):
         self.save_and_log_message(
             report_for_unlabeled_data, cur_valid_metric=-accuracy)

+    def forward(self, batch):
+        pred, true = [], []
+        with torch.no_grad():
+            # Move every batch field onto the model device as tensors.
+            batch = type(batch)(
+                map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items()))
+            result = self.model.infer(inputs=batch)
+            result = {
+                name: result[name].cpu().detach().numpy()
+                for name in result
+            }
+            intent_probs = result['intent_probs']
+            if self.can_norm:
+                # Keep full probability vectors for score normalization.
+                pred += [intent_probs]
+            else:
+                pred += np.argmax(intent_probs, axis=1).tolist()
+            true += batch['intent_label'].cpu().detach().tolist()
+        return {'pred': pred}
+
     def infer(self, data_iter, num_batches=None, ex_data_iter=None):
         """
         Inference interface.
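The new `forward` rebuilds the incoming mapping with every field converted to a device tensor via `type(batch)(map(...))`, runs `model.infer` under `torch.no_grad()`, and returns either full probability vectors (when `can_norm` is set) or argmax intent ids. The rebuild idiom in isolation, with a stand-in for the trainer's `to_tensor` helper:

```python
import torch

def to_tensor(value, device='cpu'):
    # Stand-in for IntentTrainer.to_tensor: coerce lists/arrays to tensors.
    return torch.as_tensor(value).to(device)

batch = {'src_token': [[101, 2023, 102]], 'intent_label': [3]}
# Rebuild the same mapping type, tensorizing every value; this works for
# dict, OrderedDict, or any mapping constructible from (key, value) pairs.
batch = type(batch)(map(lambda kv: (kv[0], to_tensor(kv[1])), batch.items()))
print(batch['intent_label'])  # tensor([3])
```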
@@ -4,11 +4,12 @@ import os.path as osp
 import tempfile
 import unittest

-from tests.case.nlp.dialog_generation_case import test_case
+from tests.case.nlp.dialog_intent_case import test_case
 from maas_lib.models.nlp import DialogIntentModel
 from maas_lib.pipelines import DialogIntentPipeline, pipeline
 from maas_lib.preprocessors import DialogIntentPreprocessor
+from maas_lib.utils.constant import Tasks


 class DialogGenerationTest(unittest.TestCase):
@@ -22,19 +23,12 @@ class DialogGenerationTest(unittest.TestCase):
             model_dir=modeldir,
             text_field=preprocessor.text_field,
             config=preprocessor.config)
-        print(model.forward(None))
-        # pipeline = DialogGenerationPipeline(model=model, preprocessor=preprocessor)
-        #
-        # history_dialog_info = {}
-        # for step, item in enumerate(test_case['sng0073']['log']):
-        #     user_question = item['user']
-        #     print('user: {}'.format(user_question))
-        #
-        #     # history_dialog_info = merge(history_dialog_info,
-        #     #                             result) if step > 0 else {}
-        #     result = pipeline(user_question, history=history_dialog_info)
-        #     #
-        #     # print('sys : {}'.format(result['pred_answer']))
+        pipeline1 = DialogIntentPipeline(
+            model=model, preprocessor=preprocessor)
+        # pipeline1 = pipeline(task=Tasks.dialog_intent, model=model, preprocessor=preprocessor)
+        for item in test_case:
+            pipeline1(item)


 if __name__ == '__main__':
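With the generation scaffolding removed, the test simply loops every intent case through the pipeline. A sketch of running this single case outside the full suite; the module path is an assumption about the repository layout, and the class keeps its `DialogGenerationTest` name from the diff above:

```python
import unittest

# Module path assumed; adjust to wherever this test file actually lives.
suite = unittest.defaultTestLoader.loadTestsFromName(
    'tests.pipelines.test_dialog_intent.DialogGenerationTest')
unittest.TextTestRunner(verbosity=2).run(suite)
```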