@@ -10,6 +10,5 @@ from .multi_modal import OfaForImageCaptioning | |||||
from .nlp import (BertForMaskedLM, BertForSequenceClassification, SbertForNLI, | from .nlp import (BertForMaskedLM, BertForSequenceClassification, SbertForNLI, | ||||
SbertForSentenceSimilarity, SbertForSentimentClassification, | SbertForSentenceSimilarity, SbertForSentimentClassification, | ||||
SbertForTokenClassification, SpaceForDialogIntentModel, | SbertForTokenClassification, SpaceForDialogIntentModel, | ||||
SpaceForDialogModelingModel, | |||||
SpaceForDialogStateTrackingModel, StructBertForMaskedLM, | |||||
VecoForMaskedLM) | |||||
SpaceForDialogModelingModel, SpaceForDialogStateTracking, | |||||
StructBertForMaskedLM, VecoForMaskedLM) |
@@ -6,11 +6,11 @@ from ....utils.nlp.space.utils_dst import batch_to_device | |||||
from ...base import Model, Tensor | from ...base import Model, Tensor | ||||
from ...builder import MODELS | from ...builder import MODELS | ||||
__all__ = ['SpaceForDialogStateTrackingModel'] | |||||
__all__ = ['SpaceForDialogStateTracking'] | |||||
@MODELS.register_module(Tasks.dialog_state_tracking, module_name=r'space') | @MODELS.register_module(Tasks.dialog_state_tracking, module_name=r'space') | ||||
class SpaceForDialogStateTrackingModel(Model): | |||||
class SpaceForDialogStateTracking(Model): | |||||
def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
"""initialize the test generation model from the `model_dir` path. | """initialize the test generation model from the `model_dir` path. | ||||
@@ -1,7 +1,7 @@ | |||||
from typing import Any, Dict | from typing import Any, Dict | ||||
from ...metainfo import Pipelines | from ...metainfo import Pipelines | ||||
from ...models import SpaceForDialogStateTrackingModel | |||||
from ...models import SpaceForDialogStateTracking | |||||
from ...preprocessors import DialogStateTrackingPreprocessor | from ...preprocessors import DialogStateTrackingPreprocessor | ||||
from ...utils.constant import Tasks | from ...utils.constant import Tasks | ||||
from ..base import Pipeline | from ..base import Pipeline | ||||
@@ -14,7 +14,7 @@ __all__ = ['DialogStateTrackingPipeline'] | |||||
Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking) | Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking) | ||||
class DialogStateTrackingPipeline(Pipeline): | class DialogStateTrackingPipeline(Pipeline): | ||||
def __init__(self, model: SpaceForDialogStateTrackingModel, | |||||
def __init__(self, model: SpaceForDialogStateTracking, | |||||
preprocessor: DialogStateTrackingPreprocessor, **kwargs): | preprocessor: DialogStateTrackingPreprocessor, **kwargs): | ||||
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | ||||
@@ -5,7 +5,7 @@ import tempfile | |||||
import unittest | import unittest | ||||
from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
from modelscope.models import Model, SpaceForDialogStateTrackingModel | |||||
from modelscope.models import Model, SpaceForDialogStateTracking | |||||
from modelscope.pipelines import DialogStateTrackingPipeline, pipeline | from modelscope.pipelines import DialogStateTrackingPipeline, pipeline | ||||
from modelscope.preprocessors import DialogStateTrackingPreprocessor | from modelscope.preprocessors import DialogStateTrackingPreprocessor | ||||
from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
@@ -81,7 +81,7 @@ class DialogStateTrackingTest(unittest.TestCase): | |||||
cache_path = '/Users/yangliu/Space/maas_model/nlp_space_dialog-state-tracking' | cache_path = '/Users/yangliu/Space/maas_model/nlp_space_dialog-state-tracking' | ||||
# cache_path = snapshot_download(self.model_id) | # cache_path = snapshot_download(self.model_id) | ||||
model = SpaceForDialogStateTrackingModel(cache_path) | |||||
model = SpaceForDialogStateTracking(cache_path) | |||||
preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) | preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) | ||||
pipelines = [ | pipelines = [ | ||||
DialogStateTrackingPipeline( | DialogStateTrackingPipeline( | ||||
@@ -94,20 +94,19 @@ class DialogStateTrackingTest(unittest.TestCase): | |||||
pipelines_len = len(pipelines) | pipelines_len = len(pipelines) | ||||
import json | import json | ||||
for _test_case in self.test_case: | |||||
history_states = [{}] | |||||
utter = {} | |||||
for step, item in enumerate(_test_case): | |||||
utter.update(item) | |||||
result = pipelines[step % pipelines_len]({ | |||||
'utter': | |||||
utter, | |||||
'history_states': | |||||
history_states | |||||
}) | |||||
print(json.dumps(result)) | |||||
history_states = [{}] | |||||
utter = {} | |||||
for step, item in enumerate(self.test_case): | |||||
utter.update(item) | |||||
result = pipelines[step % pipelines_len]({ | |||||
'utter': | |||||
utter, | |||||
'history_states': | |||||
history_states | |||||
}) | |||||
print(json.dumps(result)) | |||||
history_states.extend([result['dialog_states'], {}]) | |||||
history_states.extend([result['dialog_states'], {}]) | |||||
@unittest.skip('test with snapshot_download') | @unittest.skip('test with snapshot_download') | ||||
def test_run_with_model_from_modelhub(self): | def test_run_with_model_from_modelhub(self): | ||||