@@ -63,15 +63,16 @@ class SpaceForDialogIntent(Model): | |||||
"""return the result by the model | """return the result by the model | ||||
Args: | Args: | ||||
input (Dict[str, Any]): the preprocessed data | |||||
input (Dict[str, Tensor]): the preprocessed data | |||||
Returns: | Returns: | ||||
Dict[str, np.ndarray]: results | |||||
Dict[str, Tensor]: results | |||||
Example: | Example: | ||||
{ | { | ||||
'predictions': array([1]), # lable 0-negative 1-positive | |||||
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||||
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||||
'pred': array([2.62349960e-03 4.12110658e-03 4.12748595e-05 3.77560973e-05 | |||||
1.08599677e-04 1.72710388e-05 2.95618793e-05 1.93638436e-04 | |||||
6.45841064e-05 1.15997791e-04 5.11605394e-05 9.87020373e-01 | |||||
2.66957268e-05 4.72324500e-05 9.74208378e-05], dtype=float32) | |||||
} | } | ||||
""" | """ | ||||
import numpy as np | import numpy as np | ||||
@@ -62,15 +62,17 @@ class SpaceForDialogModeling(Model): | |||||
"""return the result by the model | """return the result by the model | ||||
Args: | Args: | ||||
input (Dict[str, Any]): the preprocessed data | |||||
input (Dict[str, Tensor]): the preprocessed data | |||||
Returns: | Returns: | ||||
Dict[str, np.ndarray]: results | |||||
Dict[str, Tensor]: results | |||||
Example: | Example: | ||||
{ | { | ||||
'predictions': array([1]), # lable 0-negative 1-positive | |||||
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||||
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||||
'labels': array([1,192,321,12]), # lable | |||||
'resp': array([293,1023,123,1123]), #vocab label for response | |||||
'bspn': array([123,321,2,24,1 ]), | |||||
'aspn': array([47,8345,32,29,1983]), | |||||
'db': array([19, 24, 20]), | |||||
} | } | ||||
""" | """ | ||||
@@ -2,6 +2,7 @@ import os | |||||
from typing import Any, Dict | from typing import Any, Dict | ||||
from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
from ....metainfo import Models | |||||
from ....utils.nlp.space.utils_dst import batch_to_device | from ....utils.nlp.space.utils_dst import batch_to_device | ||||
from ...base import Model, Tensor | from ...base import Model, Tensor | ||||
from ...builder import MODELS | from ...builder import MODELS | ||||
@@ -9,7 +10,7 @@ from ...builder import MODELS | |||||
__all__ = ['SpaceForDialogStateTracking'] | __all__ = ['SpaceForDialogStateTracking'] | ||||
@MODELS.register_module(Tasks.dialog_state_tracking, module_name=r'space') | |||||
@MODELS.register_module(Tasks.dialog_state_tracking, module_name=Models.space) | |||||
class SpaceForDialogStateTracking(Model): | class SpaceForDialogStateTracking(Model): | ||||
def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
@@ -17,8 +18,6 @@ class SpaceForDialogStateTracking(Model): | |||||
Args: | Args: | ||||
model_dir (str): the model path. | model_dir (str): the model path. | ||||
model_cls (Optional[Any], optional): model loader, if None, use the | |||||
default loader to load model weights, by default None. | |||||
""" | """ | ||||
super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
@@ -27,7 +26,6 @@ class SpaceForDialogStateTracking(Model): | |||||
self.model_dir = model_dir | self.model_dir = model_dir | ||||
self.config = SpaceConfig.from_pretrained(self.model_dir) | self.config = SpaceConfig.from_pretrained(self.model_dir) | ||||
# self.model = SpaceForDST(self.config) | |||||
self.model = SpaceForDST.from_pretrained(self.model_dir) | self.model = SpaceForDST.from_pretrained(self.model_dir) | ||||
self.model.to(self.config.device) | self.model.to(self.config.device) | ||||
@@ -35,15 +33,20 @@ class SpaceForDialogStateTracking(Model): | |||||
"""return the result by the model | """return the result by the model | ||||
Args: | Args: | ||||
input (Dict[str, Any]): the preprocessed data | |||||
input (Dict[str, Tensor]): the preprocessed data | |||||
Returns: | Returns: | ||||
Dict[str, np.ndarray]: results | |||||
Dict[str, Tensor]: results | |||||
Example: | Example: | ||||
{ | { | ||||
'predictions': array([1]), # lable 0-negative 1-positive | |||||
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||||
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||||
'inputs': dict(input_ids, input_masks,start_pos), # tracking states | |||||
'outputs': dict(slots_logits), | |||||
'unique_ids': str(test-example.json-0), # default value | |||||
'input_ids_unmasked': array([101, 7632, 1010,0,0,0]) | |||||
'values': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]), | |||||
'inform': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]), | |||||
'prefix': str('final'), #default value | |||||
'ds': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]) | |||||
} | } | ||||
""" | """ | ||||
import numpy as np | import numpy as np | ||||
@@ -88,8 +91,6 @@ class SpaceForDialogStateTracking(Model): | |||||
if u != 0: | if u != 0: | ||||
diag_state[slot][i] = u | diag_state[slot][i] = u | ||||
# print(outputs) | |||||
return { | return { | ||||
'inputs': inputs, | 'inputs': inputs, | ||||
'outputs': outputs, | 'outputs': outputs, | ||||
@@ -41,6 +41,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
'damo/nlp_space_dialog-intent-prediction'), | 'damo/nlp_space_dialog-intent-prediction'), | ||||
Tasks.dialog_modeling: (Pipelines.dialog_modeling, | Tasks.dialog_modeling: (Pipelines.dialog_modeling, | ||||
'damo/nlp_space_dialog-modeling'), | 'damo/nlp_space_dialog-modeling'), | ||||
Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking, | |||||
'damo/nlp_space_dialog-state-tracking'), | |||||
Tasks.image_captioning: (Pipelines.image_caption, | Tasks.image_captioning: (Pipelines.image_caption, | ||||
'damo/ofa_image-caption_coco_large_en'), | 'damo/ofa_image-caption_coco_large_en'), | ||||
Tasks.image_generation: | Tasks.image_generation: | ||||
@@ -1,8 +1,9 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
from typing import Any, Dict | |||||
from typing import Any, Dict, Union | |||||
from ...metainfo import Pipelines | from ...metainfo import Pipelines | ||||
from ...models import Model | |||||
from ...models.nlp import SpaceForDialogIntent | from ...models.nlp import SpaceForDialogIntent | ||||
from ...preprocessors import DialogIntentPredictionPreprocessor | from ...preprocessors import DialogIntentPredictionPreprocessor | ||||
from ...utils.constant import Tasks | from ...utils.constant import Tasks | ||||
@@ -18,17 +19,22 @@ __all__ = ['DialogIntentPredictionPipeline'] | |||||
module_name=Pipelines.dialog_intent_prediction) | module_name=Pipelines.dialog_intent_prediction) | ||||
class DialogIntentPredictionPipeline(Pipeline): | class DialogIntentPredictionPipeline(Pipeline): | ||||
def __init__(self, model: SpaceForDialogIntent, | |||||
preprocessor: DialogIntentPredictionPreprocessor, **kwargs): | |||||
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||||
def __init__(self, | |||||
model: Union[SpaceForDialogIntent, str], | |||||
preprocessor: DialogIntentPredictionPreprocessor = None, | |||||
**kwargs): | |||||
"""use `model` and `preprocessor` to create a dialog intent prediction pipeline | |||||
Args: | Args: | ||||
model (SequenceClassificationModel): a model instance | |||||
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||||
model (SpaceForDialogIntent): a model instance | |||||
preprocessor (DialogIntentPredictionPreprocessor): a preprocessor instance | |||||
""" | """ | ||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
model = model if isinstance( | |||||
model, SpaceForDialogIntent) else Model.from_pretrained(model) | |||||
if preprocessor is None: | |||||
preprocessor = DialogIntentPredictionPreprocessor(model.model_dir) | |||||
self.model = model | self.model = model | ||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
self.categories = preprocessor.categories | self.categories = preprocessor.categories | ||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | ||||
@@ -1,8 +1,9 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
from typing import Any, Dict, Optional | |||||
from typing import Any, Dict, Union | |||||
from ...metainfo import Pipelines | from ...metainfo import Pipelines | ||||
from ...models import Model | |||||
from ...models.nlp import SpaceForDialogModeling | from ...models.nlp import SpaceForDialogModeling | ||||
from ...preprocessors import DialogModelingPreprocessor | from ...preprocessors import DialogModelingPreprocessor | ||||
from ...utils.constant import Tasks | from ...utils.constant import Tasks | ||||
@@ -17,17 +18,22 @@ __all__ = ['DialogModelingPipeline'] | |||||
Tasks.dialog_modeling, module_name=Pipelines.dialog_modeling) | Tasks.dialog_modeling, module_name=Pipelines.dialog_modeling) | ||||
class DialogModelingPipeline(Pipeline): | class DialogModelingPipeline(Pipeline): | ||||
def __init__(self, model: SpaceForDialogModeling, | |||||
preprocessor: DialogModelingPreprocessor, **kwargs): | |||||
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||||
def __init__(self, | |||||
model: Union[SpaceForDialogModeling, str], | |||||
preprocessor: DialogModelingPreprocessor = None, | |||||
**kwargs): | |||||
"""use `model` and `preprocessor` to create a dialog modleing pipeline for dialog response generation | |||||
Args: | Args: | ||||
model (SequenceClassificationModel): a model instance | |||||
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||||
model (SpaceForDialogModeling): a model instance | |||||
preprocessor (DialogModelingPreprocessor): a preprocessor instance | |||||
""" | """ | ||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
model = model if isinstance( | |||||
model, SpaceForDialogModeling) else Model.from_pretrained(model) | |||||
self.model = model | self.model = model | ||||
if preprocessor is None: | |||||
preprocessor = DialogModelingPreprocessor(model.model_dir) | |||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
self.preprocessor = preprocessor | self.preprocessor = preprocessor | ||||
def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: | def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: | ||||
@@ -1,7 +1,7 @@ | |||||
from typing import Any, Dict | |||||
from typing import Any, Dict, Union | |||||
from ...metainfo import Pipelines | from ...metainfo import Pipelines | ||||
from ...models import SpaceForDialogStateTracking | |||||
from ...models import Model, SpaceForDialogStateTracking | |||||
from ...preprocessors import DialogStateTrackingPreprocessor | from ...preprocessors import DialogStateTrackingPreprocessor | ||||
from ...utils.constant import Tasks | from ...utils.constant import Tasks | ||||
from ..base import Pipeline | from ..base import Pipeline | ||||
@@ -15,17 +15,26 @@ __all__ = ['DialogStateTrackingPipeline'] | |||||
Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking) | Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking) | ||||
class DialogStateTrackingPipeline(Pipeline): | class DialogStateTrackingPipeline(Pipeline): | ||||
def __init__(self, model: SpaceForDialogStateTracking, | |||||
preprocessor: DialogStateTrackingPreprocessor, **kwargs): | |||||
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||||
def __init__(self, | |||||
model: Union[SpaceForDialogStateTracking, str], | |||||
preprocessor: DialogStateTrackingPreprocessor = None, | |||||
**kwargs): | |||||
"""use `model` and `preprocessor` to create a dialog state tracking pipeline for | |||||
observation of dialog states tracking after many turns of open domain dialogue | |||||
Args: | Args: | ||||
model (SequenceClassificationModel): a model instance | |||||
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||||
model (SpaceForDialogStateTracking): a model instance | |||||
preprocessor (DialogStateTrackingPreprocessor): a preprocessor instance | |||||
""" | """ | ||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
model = model if isinstance( | |||||
model, | |||||
SpaceForDialogStateTracking) else Model.from_pretrained(model) | |||||
self.model = model | self.model = model | ||||
if preprocessor is None: | |||||
preprocessor = DialogStateTrackingPreprocessor(model.model_dir) | |||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
self.tokenizer = preprocessor.tokenizer | self.tokenizer = preprocessor.tokenizer | ||||
self.config = preprocessor.config | self.config = preprocessor.config | ||||
@@ -46,9 +55,7 @@ class DialogStateTrackingPipeline(Pipeline): | |||||
values = inputs['values'] | values = inputs['values'] | ||||
inform = inputs['inform'] | inform = inputs['inform'] | ||||
prefix = inputs['prefix'] | prefix = inputs['prefix'] | ||||
# ds = {slot: 'none' for slot in self.config.dst_slot_list} | |||||
ds = inputs['ds'] | ds = inputs['ds'] | ||||
ds = predict_and_format(self.config, self.tokenizer, _inputs, | ds = predict_and_format(self.config, self.tokenizer, _inputs, | ||||
_outputs[2], _outputs[3], _outputs[4], | _outputs[2], _outputs[3], _outputs[4], | ||||
_outputs[5], unique_ids, input_ids_unmasked, | _outputs[5], unique_ids, input_ids_unmasked, | ||||
@@ -138,13 +138,6 @@ TASK_OUTPUTS = { | |||||
# } | # } | ||||
Tasks.sentiment_classification: [OutputKeys.SCORES, OutputKeys.LABELS], | Tasks.sentiment_classification: [OutputKeys.SCORES, OutputKeys.LABELS], | ||||
# sentiment classification result for single sample | |||||
# { | |||||
# "labels": ["happy", "sad", "calm", "angry"], | |||||
# "scores": [0.9, 0.1, 0.05, 0.05] | |||||
# } | |||||
Tasks.sentiment_classification: ['scores', 'labels'], | |||||
# zero-shot classification result for single sample | # zero-shot classification result for single sample | ||||
# { | # { | ||||
# "scores": [0.9, 0.1, 0.05, 0.05] | # "scores": [0.9, 0.1, 0.05, 0.05] | ||||
@@ -18,7 +18,7 @@ class DialogIntentPredictionTest(unittest.TestCase): | |||||
] | ] | ||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | ||||
def test_run(self): | |||||
def test_run_by_direct_model_download(self): | |||||
cache_path = snapshot_download(self.model_id) | cache_path = snapshot_download(self.model_id) | ||||
preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) | preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) | ||||
model = SpaceForDialogIntent( | model = SpaceForDialogIntent( | ||||
@@ -56,6 +56,20 @@ class DialogIntentPredictionTest(unittest.TestCase): | |||||
for my_pipeline, item in list(zip(pipelines, self.test_case)): | for my_pipeline, item in list(zip(pipelines, self.test_case)): | ||||
print(my_pipeline(item)) | print(my_pipeline(item)) | ||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
def test_run_with_model_name(self): | |||||
pipelines = [ | |||||
pipeline(task=Tasks.dialog_intent_prediction, model=self.model_id) | |||||
] | |||||
for my_pipeline, item in list(zip(pipelines, self.test_case)): | |||||
print(my_pipeline(item)) | |||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
def test_run_with_default_model(self): | |||||
pipelines = [pipeline(task=Tasks.dialog_intent_prediction)] | |||||
for my_pipeline, item in list(zip(pipelines, self.test_case)): | |||||
print(my_pipeline(item)) | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
unittest.main() | unittest.main() |
@@ -1,5 +1,6 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
import unittest | import unittest | ||||
from typing import List | |||||
from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
from modelscope.models import Model | from modelscope.models import Model | ||||
@@ -89,8 +90,22 @@ class DialogModelingTest(unittest.TestCase): | |||||
} | } | ||||
} | } | ||||
def generate_and_print_dialog_response( | |||||
self, pipelines: List[DialogModelingPipeline]): | |||||
result = {} | |||||
for step, item in enumerate(self.test_case['sng0073']['log']): | |||||
user = item['user'] | |||||
print('user: {}'.format(user)) | |||||
result = pipelines[step % 2]({ | |||||
'user_input': user, | |||||
'history': result | |||||
}) | |||||
print('response : {}'.format(result['response'])) | |||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | ||||
def test_run(self): | |||||
def test_run_by_direct_model_download(self): | |||||
cache_path = snapshot_download(self.model_id) | cache_path = snapshot_download(self.model_id) | ||||
@@ -106,17 +121,7 @@ class DialogModelingTest(unittest.TestCase): | |||||
model=model, | model=model, | ||||
preprocessor=preprocessor) | preprocessor=preprocessor) | ||||
] | ] | ||||
result = {} | |||||
for step, item in enumerate(self.test_case['sng0073']['log']): | |||||
user = item['user'] | |||||
print('user: {}'.format(user)) | |||||
result = pipelines[step % 2]({ | |||||
'user_input': user, | |||||
'history': result | |||||
}) | |||||
print('response : {}'.format(result['response'])) | |||||
self.generate_and_print_dialog_response(pipelines) | |||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
def test_run_with_model_from_modelhub(self): | def test_run_with_model_from_modelhub(self): | ||||
@@ -131,16 +136,23 @@ class DialogModelingTest(unittest.TestCase): | |||||
preprocessor=preprocessor) | preprocessor=preprocessor) | ||||
] | ] | ||||
result = {} | |||||
for step, item in enumerate(self.test_case['sng0073']['log']): | |||||
user = item['user'] | |||||
print('user: {}'.format(user)) | |||||
self.generate_and_print_dialog_response(pipelines) | |||||
result = pipelines[step % 2]({ | |||||
'user_input': user, | |||||
'history': result | |||||
}) | |||||
print('response : {}'.format(result['response'])) | |||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
def test_run_with_model_name(self): | |||||
pipelines = [ | |||||
pipeline(task=Tasks.dialog_modeling, model=self.model_id), | |||||
pipeline(task=Tasks.dialog_modeling, model=self.model_id) | |||||
] | |||||
self.generate_and_print_dialog_response(pipelines) | |||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
def test_run_with_default_model(self): | |||||
pipelines = [ | |||||
pipeline(task=Tasks.dialog_modeling), | |||||
pipeline(task=Tasks.dialog_modeling) | |||||
] | |||||
self.generate_and_print_dialog_response(pipelines) | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
@@ -1,5 +1,6 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
import unittest | import unittest | ||||
from typing import List | |||||
from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
from modelscope.models import Model, SpaceForDialogStateTracking | from modelscope.models import Model, SpaceForDialogStateTracking | ||||
@@ -75,23 +76,10 @@ class DialogStateTrackingTest(unittest.TestCase): | |||||
'User-8': 'Thank you, goodbye', | 'User-8': 'Thank you, goodbye', | ||||
}] | }] | ||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
def test_run(self): | |||||
cache_path = snapshot_download(self.model_id) | |||||
model = SpaceForDialogStateTracking(cache_path) | |||||
preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) | |||||
pipelines = [ | |||||
DialogStateTrackingPipeline( | |||||
model=model, preprocessor=preprocessor), | |||||
pipeline( | |||||
task=Tasks.dialog_state_tracking, | |||||
model=model, | |||||
preprocessor=preprocessor) | |||||
] | |||||
pipelines_len = len(pipelines) | |||||
def tracking_and_print_dialog_states( | |||||
self, pipelines: List[DialogStateTrackingPipeline]): | |||||
import json | import json | ||||
pipelines_len = len(pipelines) | |||||
history_states = [{}] | history_states = [{}] | ||||
utter = {} | utter = {} | ||||
for step, item in enumerate(self.test_case): | for step, item in enumerate(self.test_case): | ||||
@@ -106,6 +94,22 @@ class DialogStateTrackingTest(unittest.TestCase): | |||||
history_states.extend([result['dialog_states'], {}]) | history_states.extend([result['dialog_states'], {}]) | ||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
def test_run_by_direct_model_download(self): | |||||
cache_path = snapshot_download(self.model_id) | |||||
model = SpaceForDialogStateTracking(cache_path) | |||||
preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) | |||||
pipelines = [ | |||||
DialogStateTrackingPipeline( | |||||
model=model, preprocessor=preprocessor), | |||||
pipeline( | |||||
task=Tasks.dialog_state_tracking, | |||||
model=model, | |||||
preprocessor=preprocessor) | |||||
] | |||||
self.tracking_and_print_dialog_states(pipelines) | |||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
def test_run_with_model_from_modelhub(self): | def test_run_with_model_from_modelhub(self): | ||||
model = Model.from_pretrained(self.model_id) | model = Model.from_pretrained(self.model_id) | ||||
@@ -120,21 +124,19 @@ class DialogStateTrackingTest(unittest.TestCase): | |||||
preprocessor=preprocessor) | preprocessor=preprocessor) | ||||
] | ] | ||||
pipelines_len = len(pipelines) | |||||
import json | |||||
history_states = [{}] | |||||
utter = {} | |||||
for step, item in enumerate(self.test_case): | |||||
utter.update(item) | |||||
result = pipelines[step % pipelines_len]({ | |||||
'utter': | |||||
utter, | |||||
'history_states': | |||||
history_states | |||||
}) | |||||
print(json.dumps(result)) | |||||
self.tracking_and_print_dialog_states(pipelines) | |||||
history_states.extend([result['dialog_states'], {}]) | |||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
def test_run_with_model_name(self): | |||||
pipelines = [ | |||||
pipeline(task=Tasks.dialog_state_tracking, model=self.model_id) | |||||
] | |||||
self.tracking_and_print_dialog_states(pipelines) | |||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
def test_run_with_default_model(self): | |||||
pipelines = [pipeline(task=Tasks.dialog_state_tracking)] | |||||
self.tracking_and_print_dialog_states(pipelines) | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||