Browse Source

add test cases

master
智丞 3 years ago
parent
commit
154c61fc25
11 changed files with 152 additions and 106 deletions
  1. +6
    -5
      modelscope/models/nlp/space/dialog_intent_prediction_model.py
  2. +7
    -5
      modelscope/models/nlp/space/dialog_modeling_model.py
  3. +12
    -11
      modelscope/models/nlp/space/dialog_state_tracking_model.py
  4. +2
    -0
      modelscope/pipelines/builder.py
  5. +14
    -8
      modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py
  6. +14
    -8
      modelscope/pipelines/nlp/dialog_modeling_pipeline.py
  7. +17
    -10
      modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py
  8. +0
    -7
      modelscope/pipelines/outputs.py
  9. +15
    -1
      tests/pipelines/test_dialog_intent_prediction.py
  10. +33
    -21
      tests/pipelines/test_dialog_modeling.py
  11. +32
    -30
      tests/pipelines/test_dialog_state_tracking.py

+ 6
- 5
modelscope/models/nlp/space/dialog_intent_prediction_model.py View File

@@ -63,15 +63,16 @@ class SpaceForDialogIntent(Model):
"""return the result by the model

Args:
input (Dict[str, Any]): the preprocessed data
input (Dict[str, Tensor]): the preprocessed data

Returns:
Dict[str, np.ndarray]: results
Dict[str, Tensor]: results
Example:
{
'predictions': array([1]), # label 0-negative 1-positive
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
'pred': array([2.62349960e-03 4.12110658e-03 4.12748595e-05 3.77560973e-05
1.08599677e-04 1.72710388e-05 2.95618793e-05 1.93638436e-04
6.45841064e-05 1.15997791e-04 5.11605394e-05 9.87020373e-01
2.66957268e-05 4.72324500e-05 9.74208378e-05], dtype=float32)
}
"""
import numpy as np


+ 7
- 5
modelscope/models/nlp/space/dialog_modeling_model.py View File

@@ -62,15 +62,17 @@ class SpaceForDialogModeling(Model):
"""return the result by the model

Args:
input (Dict[str, Any]): the preprocessed data
input (Dict[str, Tensor]): the preprocessed data

Returns:
Dict[str, np.ndarray]: results
Dict[str, Tensor]: results
Example:
{
'predictions': array([1]), # label 0-negative 1-positive
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
'labels': array([1,192,321,12]), # label
'resp': array([293,1023,123,1123]), #vocab label for response
'bspn': array([123,321,2,24,1 ]),
'aspn': array([47,8345,32,29,1983]),
'db': array([19, 24, 20]),
}
"""



+ 12
- 11
modelscope/models/nlp/space/dialog_state_tracking_model.py View File

@@ -2,6 +2,7 @@ import os
from typing import Any, Dict

from modelscope.utils.constant import Tasks
from ....metainfo import Models
from ....utils.nlp.space.utils_dst import batch_to_device
from ...base import Model, Tensor
from ...builder import MODELS
@@ -9,7 +10,7 @@ from ...builder import MODELS
__all__ = ['SpaceForDialogStateTracking']


@MODELS.register_module(Tasks.dialog_state_tracking, module_name=r'space')
@MODELS.register_module(Tasks.dialog_state_tracking, module_name=Models.space)
class SpaceForDialogStateTracking(Model):

def __init__(self, model_dir: str, *args, **kwargs):
@@ -17,8 +18,6 @@ class SpaceForDialogStateTracking(Model):

Args:
model_dir (str): the model path.
model_cls (Optional[Any], optional): model loader, if None, use the
default loader to load model weights, by default None.
"""

super().__init__(model_dir, *args, **kwargs)
@@ -27,7 +26,6 @@ class SpaceForDialogStateTracking(Model):
self.model_dir = model_dir

self.config = SpaceConfig.from_pretrained(self.model_dir)
# self.model = SpaceForDST(self.config)
self.model = SpaceForDST.from_pretrained(self.model_dir)
self.model.to(self.config.device)

@@ -35,15 +33,20 @@ class SpaceForDialogStateTracking(Model):
"""return the result by the model

Args:
input (Dict[str, Any]): the preprocessed data
input (Dict[str, Tensor]): the preprocessed data

Returns:
Dict[str, np.ndarray]: results
Dict[str, Tensor]: results
Example:
{
'predictions': array([1]), # label 0-negative 1-positive
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
'inputs': dict(input_ids, input_masks,start_pos), # tracking states
'outputs': dict(slots_logits),
'unique_ids': str(test-example.json-0), # default value
'input_ids_unmasked': array([101, 7632, 1010,0,0,0])
'values': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]),
'inform': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]),
'prefix': str('final'), #default value
'ds': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}])
}
"""
import numpy as np
@@ -88,8 +91,6 @@ class SpaceForDialogStateTracking(Model):
if u != 0:
diag_state[slot][i] = u

# print(outputs)

return {
'inputs': inputs,
'outputs': outputs,


+ 2
- 0
modelscope/pipelines/builder.py View File

@@ -41,6 +41,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/nlp_space_dialog-intent-prediction'),
Tasks.dialog_modeling: (Pipelines.dialog_modeling,
'damo/nlp_space_dialog-modeling'),
Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
'damo/nlp_space_dialog-state-tracking'),
Tasks.image_captioning: (Pipelines.image_caption,
'damo/ofa_image-caption_coco_large_en'),
Tasks.image_generation:


+ 14
- 8
modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py View File

@@ -1,8 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict
from typing import Any, Dict, Union

from ...metainfo import Pipelines
from ...models import Model
from ...models.nlp import SpaceForDialogIntent
from ...preprocessors import DialogIntentPredictionPreprocessor
from ...utils.constant import Tasks
@@ -18,17 +19,22 @@ __all__ = ['DialogIntentPredictionPipeline']
module_name=Pipelines.dialog_intent_prediction)
class DialogIntentPredictionPipeline(Pipeline):

def __init__(self, model: SpaceForDialogIntent,
preprocessor: DialogIntentPredictionPreprocessor, **kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction
def __init__(self,
model: Union[SpaceForDialogIntent, str],
preprocessor: DialogIntentPredictionPreprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a dialog intent prediction pipeline

Args:
model (SequenceClassificationModel): a model instance
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
model (SpaceForDialogIntent): a model instance
preprocessor (DialogIntentPredictionPreprocessor): a preprocessor instance
"""

super().__init__(model=model, preprocessor=preprocessor, **kwargs)
model = model if isinstance(
model, SpaceForDialogIntent) else Model.from_pretrained(model)
if preprocessor is None:
preprocessor = DialogIntentPredictionPreprocessor(model.model_dir)
self.model = model
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.categories = preprocessor.categories

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:


+ 14
- 8
modelscope/pipelines/nlp/dialog_modeling_pipeline.py View File

@@ -1,8 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict, Optional
from typing import Any, Dict, Union

from ...metainfo import Pipelines
from ...models import Model
from ...models.nlp import SpaceForDialogModeling
from ...preprocessors import DialogModelingPreprocessor
from ...utils.constant import Tasks
@@ -17,17 +18,22 @@ __all__ = ['DialogModelingPipeline']
Tasks.dialog_modeling, module_name=Pipelines.dialog_modeling)
class DialogModelingPipeline(Pipeline):

def __init__(self, model: SpaceForDialogModeling,
preprocessor: DialogModelingPreprocessor, **kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction
def __init__(self,
model: Union[SpaceForDialogModeling, str],
preprocessor: DialogModelingPreprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a dialog modeling pipeline for dialog response generation

Args:
model (SequenceClassificationModel): a model instance
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
model (SpaceForDialogModeling): a model instance
preprocessor (DialogModelingPreprocessor): a preprocessor instance
"""
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
model = model if isinstance(
model, SpaceForDialogModeling) else Model.from_pretrained(model)
self.model = model
if preprocessor is None:
preprocessor = DialogModelingPreprocessor(model.model_dir)
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.preprocessor = preprocessor

def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:


+ 17
- 10
modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py View File

@@ -1,7 +1,7 @@
from typing import Any, Dict
from typing import Any, Dict, Union

from ...metainfo import Pipelines
from ...models import SpaceForDialogStateTracking
from ...models import Model, SpaceForDialogStateTracking
from ...preprocessors import DialogStateTrackingPreprocessor
from ...utils.constant import Tasks
from ..base import Pipeline
@@ -15,17 +15,26 @@ __all__ = ['DialogStateTrackingPipeline']
Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking)
class DialogStateTrackingPipeline(Pipeline):

def __init__(self, model: SpaceForDialogStateTracking,
preprocessor: DialogStateTrackingPreprocessor, **kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction
def __init__(self,
model: Union[SpaceForDialogStateTracking, str],
preprocessor: DialogStateTrackingPreprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a dialog state tracking pipeline for
observation of dialog states tracking after many turns of open domain dialogue

Args:
model (SequenceClassificationModel): a model instance
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
model (SpaceForDialogStateTracking): a model instance
preprocessor (DialogStateTrackingPreprocessor): a preprocessor instance
"""

super().__init__(model=model, preprocessor=preprocessor, **kwargs)
model = model if isinstance(
model,
SpaceForDialogStateTracking) else Model.from_pretrained(model)
self.model = model
if preprocessor is None:
preprocessor = DialogStateTrackingPreprocessor(model.model_dir)
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

self.tokenizer = preprocessor.tokenizer
self.config = preprocessor.config

@@ -46,9 +55,7 @@ class DialogStateTrackingPipeline(Pipeline):
values = inputs['values']
inform = inputs['inform']
prefix = inputs['prefix']
# ds = {slot: 'none' for slot in self.config.dst_slot_list}
ds = inputs['ds']

ds = predict_and_format(self.config, self.tokenizer, _inputs,
_outputs[2], _outputs[3], _outputs[4],
_outputs[5], unique_ids, input_ids_unmasked,


+ 0
- 7
modelscope/pipelines/outputs.py View File

@@ -138,13 +138,6 @@ TASK_OUTPUTS = {
# }
Tasks.sentiment_classification: [OutputKeys.SCORES, OutputKeys.LABELS],

# sentiment classification result for single sample
# {
# "labels": ["happy", "sad", "calm", "angry"],
# "scores": [0.9, 0.1, 0.05, 0.05]
# }
Tasks.sentiment_classification: ['scores', 'labels'],

# zero-shot classification result for single sample
# {
# "scores": [0.9, 0.1, 0.05, 0.05]


+ 15
- 1
tests/pipelines/test_dialog_intent_prediction.py View File

@@ -18,7 +18,7 @@ class DialogIntentPredictionTest(unittest.TestCase):
]

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run(self):
def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id)
preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path)
model = SpaceForDialogIntent(
@@ -56,6 +56,20 @@ class DialogIntentPredictionTest(unittest.TestCase):
for my_pipeline, item in list(zip(pipelines, self.test_case)):
print(my_pipeline(item))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
pipelines = [
pipeline(task=Tasks.dialog_intent_prediction, model=self.model_id)
]
for my_pipeline, item in list(zip(pipelines, self.test_case)):
print(my_pipeline(item))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipelines = [pipeline(task=Tasks.dialog_intent_prediction)]
for my_pipeline, item in list(zip(pipelines, self.test_case)):
print(my_pipeline(item))


if __name__ == '__main__':
unittest.main()

+ 33
- 21
tests/pipelines/test_dialog_modeling.py View File

@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from typing import List

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
@@ -89,8 +90,22 @@ class DialogModelingTest(unittest.TestCase):
}
}

def generate_and_print_dialog_response(
self, pipelines: List[DialogModelingPipeline]):

result = {}
for step, item in enumerate(self.test_case['sng0073']['log']):
user = item['user']
print('user: {}'.format(user))

result = pipelines[step % 2]({
'user_input': user,
'history': result
})
print('response : {}'.format(result['response']))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run(self):
def test_run_by_direct_model_download(self):

cache_path = snapshot_download(self.model_id)

@@ -106,17 +121,7 @@ class DialogModelingTest(unittest.TestCase):
model=model,
preprocessor=preprocessor)
]

result = {}
for step, item in enumerate(self.test_case['sng0073']['log']):
user = item['user']
print('user: {}'.format(user))

result = pipelines[step % 2]({
'user_input': user,
'history': result
})
print('response : {}'.format(result['response']))
self.generate_and_print_dialog_response(pipelines)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
@@ -131,16 +136,23 @@ class DialogModelingTest(unittest.TestCase):
preprocessor=preprocessor)
]

result = {}
for step, item in enumerate(self.test_case['sng0073']['log']):
user = item['user']
print('user: {}'.format(user))
self.generate_and_print_dialog_response(pipelines)

result = pipelines[step % 2]({
'user_input': user,
'history': result
})
print('response : {}'.format(result['response']))
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
pipelines = [
pipeline(task=Tasks.dialog_modeling, model=self.model_id),
pipeline(task=Tasks.dialog_modeling, model=self.model_id)
]
self.generate_and_print_dialog_response(pipelines)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipelines = [
pipeline(task=Tasks.dialog_modeling),
pipeline(task=Tasks.dialog_modeling)
]
self.generate_and_print_dialog_response(pipelines)


if __name__ == '__main__':


+ 32
- 30
tests/pipelines/test_dialog_state_tracking.py View File

@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from typing import List

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model, SpaceForDialogStateTracking
@@ -75,23 +76,10 @@ class DialogStateTrackingTest(unittest.TestCase):
'User-8': 'Thank you, goodbye',
}]

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run(self):
cache_path = snapshot_download(self.model_id)

model = SpaceForDialogStateTracking(cache_path)
preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path)
pipelines = [
DialogStateTrackingPipeline(
model=model, preprocessor=preprocessor),
pipeline(
task=Tasks.dialog_state_tracking,
model=model,
preprocessor=preprocessor)
]

pipelines_len = len(pipelines)
def tracking_and_print_dialog_states(
self, pipelines: List[DialogStateTrackingPipeline]):
import json
pipelines_len = len(pipelines)
history_states = [{}]
utter = {}
for step, item in enumerate(self.test_case):
@@ -106,6 +94,22 @@ class DialogStateTrackingTest(unittest.TestCase):

history_states.extend([result['dialog_states'], {}])

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id)

model = SpaceForDialogStateTracking(cache_path)
preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path)
pipelines = [
DialogStateTrackingPipeline(
model=model, preprocessor=preprocessor),
pipeline(
task=Tasks.dialog_state_tracking,
model=model,
preprocessor=preprocessor)
]
self.tracking_and_print_dialog_states(pipelines)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
@@ -120,21 +124,19 @@ class DialogStateTrackingTest(unittest.TestCase):
preprocessor=preprocessor)
]

pipelines_len = len(pipelines)
import json
history_states = [{}]
utter = {}
for step, item in enumerate(self.test_case):
utter.update(item)
result = pipelines[step % pipelines_len]({
'utter':
utter,
'history_states':
history_states
})
print(json.dumps(result))
self.tracking_and_print_dialog_states(pipelines)

history_states.extend([result['dialog_states'], {}])
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
pipelines = [
pipeline(task=Tasks.dialog_state_tracking, model=self.model_id)
]
self.tracking_and_print_dialog_states(pipelines)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipelines = [pipeline(task=Tasks.dialog_state_tracking)]
self.tracking_and_print_dialog_states(pipelines)


if __name__ == '__main__':


Loading…
Cancel
Save