Browse Source

intent success

master
ly119399 3 years ago
parent
commit
c6ec7f2fa4
5 changed files with 41 additions and 33 deletions
  1. +3
    -1
      maas_lib/models/nlp/space/dialog_intent_model.py
  2. +5
    -17
      maas_lib/pipelines/nlp/space/dialog_intent_pipeline.py
  3. +3
    -1
      maas_lib/preprocessors/space/dialog_intent_preprocessor.py
  4. +22
    -0
      maas_lib/trainers/nlp/space/trainers/intent_trainer.py
  5. +8
    -14
      tests/pipelines/nlp/test_dialog_intent.py

+ 3
- 1
maas_lib/models/nlp/space/dialog_intent_model.py View File

@@ -65,5 +65,7 @@ class DialogIntentModel(Model):
"""
from numpy import array, float32
import torch
print('--forward--')
result = self.trainer.forward(input)

return {}
return result

+ 5
- 17
maas_lib/pipelines/nlp/space/dialog_intent_pipeline.py View File

@@ -3,14 +3,14 @@ from typing import Any, Dict, Optional
from maas_lib.models.nlp import DialogIntentModel
from maas_lib.preprocessors import DialogIntentPreprocessor
from maas_lib.utils.constant import Tasks
from ...base import Model, Tensor
from ...base import Input, Pipeline
from ...builder import PIPELINES

__all__ = ['DialogIntentPipeline']


@PIPELINES.register_module(Tasks.dialog_intent, module_name=r'space-intent')
class DialogIntentPipeline(Model):
class DialogIntentPipeline(Pipeline):

def __init__(self, model: DialogIntentModel,
preprocessor: DialogIntentPreprocessor, **kwargs):
@@ -23,9 +23,9 @@ class DialogIntentPipeline(Model):

super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.model = model
self.tokenizer = preprocessor.tokenizer
# self.tokenizer = preprocessor.tokenizer

def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""process the prediction results

Args:
@@ -35,16 +35,4 @@ class DialogIntentPipeline(Model):
Dict[str, str]: the prediction results
"""

vocab_size = len(self.tokenizer.vocab)
pred_list = inputs['predictions']
pred_ids = pred_list[0][0].cpu().numpy().tolist()
for j in range(len(pred_ids)):
if pred_ids[j] >= vocab_size:
pred_ids[j] = 100
pred = self.tokenizer.convert_ids_to_tokens(pred_ids)
pred_string = ''.join(pred).replace(
'##',
'').split('[SEP]')[0].replace('[CLS]',
'').replace('[SEP]',
'').replace('[UNK]', '')
return {'pred_string': pred_string}
return inputs

+ 3
- 1
maas_lib/preprocessors/space/dialog_intent_preprocessor.py View File

@@ -43,5 +43,7 @@ class DialogIntentPreprocessor(Preprocessor):
Returns:
Dict[str, Any]: the preprocessed data
"""
samples = self.text_field.preprocessor([data])
samples, _ = self.text_field.collate_fn_multi_turn(samples)

return self.text_field.preprocessor(data)
return samples

+ 22
- 0
maas_lib/trainers/nlp/space/trainers/intent_trainer.py View File

@@ -506,6 +506,28 @@ class IntentTrainer(Trainer):
self.save_and_log_message(
report_for_unlabeled_data, cur_valid_metric=-accuracy)

def forward(self, batch):
    """Run inference on a single batch and return its predictions.

    Every value of ``batch`` is converted with ``self.to_tensor`` and the
    mapping is rebuilt with the same mapping type before being handed to
    ``self.model.infer``. The model outputs are moved to CPU and turned
    into numpy arrays.

    Ground-truth labels are not needed to produce predictions, so the
    batch is not required to carry an ``'intent_label'`` entry.

    Args:
        batch: mapping of field name to data accepted by
            ``self.to_tensor``.

    Returns:
        dict: ``{'pred': pred}`` where ``pred`` is a list. When
        ``self.can_norm`` is truthy it holds the ``intent_probs`` array
        as a single element; otherwise it holds the argmax over axis 1
        of ``intent_probs`` as a plain Python list.
    """
    pred = []

    with torch.no_grad():
        # Rebuild the batch with the same mapping type, values as tensors.
        batch = type(batch)(
            (name, self.to_tensor(value)) for name, value in batch.items())
        result = self.model.infer(inputs=batch)
        result = {
            name: value.cpu().detach().numpy()
            for name, value in result.items()
        }

    intent_probs = result['intent_probs']
    if self.can_norm:
        # Keep the full probability array for downstream normalisation.
        pred.append(intent_probs)
    else:
        pred.extend(np.argmax(intent_probs, axis=1).tolist())

    return {'pred': pred}

def infer(self, data_iter, num_batches=None, ex_data_iter=None):
"""
Inference interface.


+ 8
- 14
tests/pipelines/nlp/test_dialog_intent.py View File

@@ -4,11 +4,12 @@ import os.path as osp
import tempfile
import unittest

from tests.case.nlp.dialog_generation_case import test_case
from tests.case.nlp.dialog_intent_case import test_case

from maas_lib.models.nlp import DialogIntentModel
from maas_lib.pipelines import DialogIntentPipeline, pipeline
from maas_lib.preprocessors import DialogIntentPreprocessor
from maas_lib.utils.constant import Tasks


class DialogGenerationTest(unittest.TestCase):
@@ -22,19 +23,12 @@ class DialogGenerationTest(unittest.TestCase):
model_dir=modeldir,
text_field=preprocessor.text_field,
config=preprocessor.config)
print(model.forward(None))
# pipeline = DialogGenerationPipeline(model=model, preprocessor=preprocessor)
#
# history_dialog_info = {}
# for step, item in enumerate(test_case['sng0073']['log']):
# user_question = item['user']
# print('user: {}'.format(user_question))
#
# # history_dialog_info = merge(history_dialog_info,
# # result) if step > 0 else {}
# result = pipeline(user_question, history=history_dialog_info)
# #
# # print('sys : {}'.format(result['pred_answer']))
pipeline1 = DialogIntentPipeline(
model=model, preprocessor=preprocessor)
# pipeline1 = pipeline(task=Tasks.dialog_intent, model=model, preprocessor=preprocessor)

for item in test_case:
pipeline1(item)


if __name__ == '__main__':


Loading…
Cancel
Save