diff --git a/modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py b/modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py index fc257e09..df0a185e 100644 --- a/modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py +++ b/modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py @@ -6,6 +6,7 @@ from ...preprocessors import DialogStateTrackingPreprocessor from ...utils.constant import Tasks from ..base import Pipeline from ..builder import PIPELINES +from ..outputs import OutputKeys __all__ = ['DialogStateTrackingPipeline'] @@ -53,7 +54,7 @@ class DialogStateTrackingPipeline(Pipeline): _outputs[5], unique_ids, input_ids_unmasked, values, inform, prefix, ds) - return {'dialog_states': ds} + return {OutputKeys.DIALOG_STATES: ds} def predict_and_format(config, tokenizer, features, per_slot_class_logits, diff --git a/modelscope/pipelines/outputs.py b/modelscope/pipelines/outputs.py index 8fcf498b..4126b538 100644 --- a/modelscope/pipelines/outputs.py +++ b/modelscope/pipelines/outputs.py @@ -20,6 +20,7 @@ class OutputKeys(object): TEXT_EMBEDDING = 'text_embedding' RESPONSE = 'response' PREDICTION = 'prediction' + DIALOG_STATES = 'dialog_states' TASK_OUTPUTS = { @@ -151,6 +152,7 @@ TASK_OUTPUTS = { # } Tasks.nli: [OutputKeys.SCORES, OutputKeys.LABELS], + # dialog intent prediction result for single sample # {'pred': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05, # 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04, # 6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01, @@ -174,16 +176,11 @@ TASK_OUTPUTS = { Tasks.dialog_intent_prediction: [OutputKeys.PREDICTION, OutputKeys.LABEL_POS, OutputKeys.LABEL], + # dialog modeling prediction result for single sample # sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!'] Tasks.dialog_modeling: [OutputKeys.RESPONSE], - # nli result for single sample - # { - # "labels": ["happy", "sad", "calm", "angry"], - # "scores": [0.9, 0.1, 0.05, 0.05] - # } - Tasks.nli: ['scores', 'labels'], - + # dialog state tracking result for single sample # { # "dialog_states": { # "taxi-leaveAt": "none", @@ -218,32 +215,7 @@ TASK_OUTPUTS = { # "train-departure": "none" # } # } - Tasks.dialog_state_tracking: ['dialog_states'], - - # {'pred': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05, - # 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04, - # 6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01, - # 2.66957268e-05, 4.72324500e-05, 9.74208378e-05, 4.18022355e-05, - # 2.97343540e-05, 5.81317654e-05, 5.44203431e-05, 6.28319322e-05, - # 7.34537680e-05, 6.61411541e-05, 3.62534920e-05, 8.58885178e-05, - # 8.24327726e-05, 4.66077945e-05, 5.32869453e-05, 4.16190960e-05, - # 5.97518992e-05, 3.92273068e-05, 3.44069012e-05, 9.92335918e-05, - # 9.25978165e-05, 6.26462061e-05, 3.32317031e-05, 1.32061413e-03, - # 2.01607945e-05, 3.36636294e-05, 3.99156743e-05, 5.84108493e-05, - # 2.53432900e-05, 4.95731190e-04, 2.64443643e-05, 4.46992999e-05, - # 2.42672231e-05, 4.75615161e-05, 2.66230145e-05, 4.00083954e-05, - # 2.90536875e-04, 4.23891543e-05, 8.63691166e-05, 4.98188965e-05, - # 3.47019341e-05, 4.52718523e-05, 4.20905781e-05, 5.50173208e-05, - # 4.92360487e-05, 3.56021264e-05, 2.13957210e-05, 6.17428886e-05, - # 1.43893281e-04, 7.32152112e-05, 2.91354867e-04, 2.46623786e-05, - # 3.61441926e-05, 3.38475402e-05, 3.44323053e-05, 5.70138109e-05, - # 4.31488479e-05, 4.94503947e-05, 4.30105974e-05, 1.00963116e-04, - # 2.82062047e-05, 1.15582036e-04, 4.48261271e-05, 3.99339879e-05, - # 7.27692823e-05], dtype=float32), 'label_pos': array([11]), 'label': 'lost_or_stolen_card'} - Tasks.dialog_intent_prediction: ['prediction', 'label_pos', 'label'], - - # sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!'] - Tasks.dialog_modeling: ['response'], + Tasks.dialog_state_tracking: [OutputKeys.DIALOG_STATES], # ============ audio tasks ===================