@@ -1,5 +1,5 @@ | |||||
MODELSCOPE_URL_SCHEME = 'http://' | MODELSCOPE_URL_SCHEME = 'http://' | ||||
DEFAULT_MODELSCOPE_DOMAIN = '101.201.119.157:32330' | |||||
DEFAULT_MODELSCOPE_DOMAIN = '47.94.223.21:31090' | |||||
DEFAULT_MODELSCOPE_GITLAB_DOMAIN = '101.201.119.157:31102' | DEFAULT_MODELSCOPE_GITLAB_DOMAIN = '101.201.119.157:31102' | ||||
DEFAULT_MODELSCOPE_GROUP = 'damo' | DEFAULT_MODELSCOPE_GROUP = 'damo' | ||||
@@ -16,6 +16,7 @@ class Models(object): | |||||
palm = 'palm-v2' | palm = 'palm-v2' | ||||
structbert = 'structbert' | structbert = 'structbert' | ||||
veco = 'veco' | veco = 'veco' | ||||
space = 'space' | |||||
# audio models | # audio models | ||||
sambert_hifi_16k = 'sambert-hifi-16k' | sambert_hifi_16k = 'sambert-hifi-16k' | ||||
@@ -67,7 +68,7 @@ class Pipelines(object): | |||||
kws_kwsbp = 'kws-kwsbp' | kws_kwsbp = 'kws-kwsbp' | ||||
# multi-modal tasks | # multi-modal tasks | ||||
image_caption = 'image-caption' | |||||
image_caption = 'image-captioning' | |||||
multi_modal_embedding = 'multi-modal-embedding' | multi_modal_embedding = 'multi-modal-embedding' | ||||
visual_question_answering = 'visual-question-answering' | visual_question_answering = 'visual-question-answering' | ||||
@@ -105,6 +106,9 @@ class Preprocessors(object): | |||||
token_cls_tokenizer = 'token-cls-tokenizer' | token_cls_tokenizer = 'token-cls-tokenizer' | ||||
nli_tokenizer = 'nli-tokenizer' | nli_tokenizer = 'nli-tokenizer' | ||||
sen_cls_tokenizer = 'sen-cls-tokenizer' | sen_cls_tokenizer = 'sen-cls-tokenizer' | ||||
dialog_intent_preprocessor = 'dialog-intent-preprocessor' | |||||
dialog_modeling_preprocessor = 'dialog-modeling-preprocessor' | |||||
dialog_state_tracking_preprocessor = 'dialog_state_tracking_preprocessor' | |||||
sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer' | sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer' | ||||
zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' | zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' | ||||
@@ -9,6 +9,7 @@ from .builder import MODELS, build_model | |||||
from .multi_modal import OfaForImageCaptioning | from .multi_modal import OfaForImageCaptioning | ||||
from .nlp import (BertForMaskedLM, BertForSequenceClassification, SbertForNLI, | from .nlp import (BertForMaskedLM, BertForSequenceClassification, SbertForNLI, | ||||
SbertForSentenceSimilarity, SbertForSentimentClassification, | SbertForSentenceSimilarity, SbertForSentimentClassification, | ||||
SbertForTokenClassification, SpaceForDialogIntentModel, | |||||
SpaceForDialogModelingModel, SpaceForDialogStateTracking, | |||||
StructBertForMaskedLM, VecoForMaskedLM) | |||||
SbertForTokenClassification, SbertForZeroShotClassification, | |||||
SpaceForDialogIntent, SpaceForDialogModeling, | |||||
SpaceForDialogStateTracking, StructBertForMaskedLM, | |||||
VecoForMaskedLM) |
@@ -1,6 +1,9 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import os | import os | ||||
from typing import Any, Dict | from typing import Any, Dict | ||||
from ....metainfo import Models | |||||
from ....preprocessors.space.fields.intent_field import IntentBPETextField | from ....preprocessors.space.fields.intent_field import IntentBPETextField | ||||
from ....trainers.nlp.space.trainer.intent_trainer import IntentTrainer | from ....trainers.nlp.space.trainer.intent_trainer import IntentTrainer | ||||
from ....utils.config import Config | from ....utils.config import Config | ||||
@@ -10,19 +13,18 @@ from ...builder import MODELS | |||||
from .model.generator import Generator | from .model.generator import Generator | ||||
from .model.model_base import SpaceModelBase | from .model.model_base import SpaceModelBase | ||||
__all__ = ['SpaceForDialogIntentModel'] | |||||
__all__ = ['SpaceForDialogIntent'] | |||||
@MODELS.register_module(Tasks.dialog_intent_prediction, module_name=r'space') | |||||
class SpaceForDialogIntentModel(Model): | |||||
@MODELS.register_module( | |||||
Tasks.dialog_intent_prediction, module_name=Models.space) | |||||
class SpaceForDialogIntent(Model): | |||||
def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
"""initialize the test generation model from the `model_dir` path. | """initialize the test generation model from the `model_dir` path. | ||||
Args: | Args: | ||||
model_dir (str): the model path. | model_dir (str): the model path. | ||||
model_cls (Optional[Any], optional): model loader, if None, use the | |||||
default loader to load model weights, by default None. | |||||
""" | """ | ||||
super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
@@ -1,6 +1,9 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import os | import os | ||||
from typing import Any, Dict, Optional | from typing import Any, Dict, Optional | ||||
from ....metainfo import Models | |||||
from ....preprocessors.space.fields.gen_field import MultiWOZBPETextField | from ....preprocessors.space.fields.gen_field import MultiWOZBPETextField | ||||
from ....trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer | from ....trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer | ||||
from ....utils.config import Config | from ....utils.config import Config | ||||
@@ -10,19 +13,17 @@ from ...builder import MODELS | |||||
from .model.generator import Generator | from .model.generator import Generator | ||||
from .model.model_base import SpaceModelBase | from .model.model_base import SpaceModelBase | ||||
__all__ = ['SpaceForDialogModelingModel'] | |||||
__all__ = ['SpaceForDialogModeling'] | |||||
@MODELS.register_module(Tasks.dialog_modeling, module_name=r'space') | |||||
class SpaceForDialogModelingModel(Model): | |||||
@MODELS.register_module(Tasks.dialog_modeling, module_name=Models.space) | |||||
class SpaceForDialogModeling(Model): | |||||
def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
"""initialize the test generation model from the `model_dir` path. | """initialize the test generation model from the `model_dir` path. | ||||
Args: | Args: | ||||
model_dir (str): the model path. | model_dir (str): the model path. | ||||
model_cls (Optional[Any], optional): model loader, if None, use the | |||||
default loader to load model weights, by default None. | |||||
""" | """ | ||||
super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
@@ -1,6 +1,5 @@ | |||||
""" | |||||
IntentUnifiedTransformer | |||||
""" | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import torch | import torch | ||||
from .unified_transformer import UnifiedTransformer | from .unified_transformer import UnifiedTransformer | ||||
@@ -1,6 +1,4 @@ | |||||
""" | |||||
Generator class. | |||||
""" | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import math | import math | ||||
@@ -1,6 +1,5 @@ | |||||
""" | |||||
IntentUnifiedTransformer | |||||
""" | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
@@ -1,6 +1,5 @@ | |||||
""" | |||||
Model base | |||||
""" | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import os | import os | ||||
import torch.nn as nn | import torch.nn as nn | ||||
@@ -1,6 +1,4 @@ | |||||
""" | |||||
UnifiedTransformer | |||||
""" | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
@@ -1,6 +1,4 @@ | |||||
""" | |||||
Embedder class. | |||||
""" | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
@@ -1,6 +1,4 @@ | |||||
""" | |||||
FeedForward class. | |||||
""" | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
@@ -1,6 +1,4 @@ | |||||
""" | |||||
Helpful functions. | |||||
""" | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
@@ -1,6 +1,4 @@ | |||||
""" | |||||
MultiheadAttention class. | |||||
""" | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
@@ -53,8 +51,6 @@ class MultiheadAttention(nn.Module): | |||||
if mask is not None: | if mask is not None: | ||||
''' | ''' | ||||
mask: [batch size, num_heads, seq_len, seq_len] | mask: [batch size, num_heads, seq_len, seq_len] | ||||
mask后两维(seq_len, seq_len)矩阵来看,其中有的行可能都是true(1),对应句子中<pad>位看的行 | |||||
导致softmax后该行的每个位置的attn prob都为1/n而非0,所以此处需重置为0 | |||||
>>> F.softmax([-1e10, -100, -100]) | >>> F.softmax([-1e10, -100, -100]) | ||||
>>> [0.00, 0.50, 0.50] | >>> [0.00, 0.50, 0.50] | ||||
@@ -1,6 +1,4 @@ | |||||
""" | |||||
TransformerBlock class. | |||||
""" | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
@@ -19,4 +19,4 @@ DOWNLOADED_DATASETS_PATH = Path( | |||||
os.getenv('DOWNLOADED_DATASETS_PATH', DEFAULT_DOWNLOADED_DATASETS_PATH)) | os.getenv('DOWNLOADED_DATASETS_PATH', DEFAULT_DOWNLOADED_DATASETS_PATH)) | ||||
MS_HUB_ENDPOINT = os.environ.get('MS_HUB_ENDPOINT', | MS_HUB_ENDPOINT = os.environ.get('MS_HUB_ENDPOINT', | ||||
'http://123.57.189.90:31752') | |||||
'http://47.94.223.21:31752') |
@@ -36,6 +36,11 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
Tasks.zero_shot_classification: | Tasks.zero_shot_classification: | ||||
(Pipelines.zero_shot_classification, | (Pipelines.zero_shot_classification, | ||||
'damo/nlp_structbert_zero-shot-classification_chinese-base'), | 'damo/nlp_structbert_zero-shot-classification_chinese-base'), | ||||
Tasks.dialog_intent_prediction: | |||||
(Pipelines.dialog_intent_prediction, | |||||
'damo/nlp_space_dialog-intent-prediction'), | |||||
Tasks.dialog_modeling: (Pipelines.dialog_modeling, | |||||
'damo/nlp_space_dialog-modeling'), | |||||
Tasks.image_captioning: (Pipelines.image_caption, | Tasks.image_captioning: (Pipelines.image_caption, | ||||
'damo/ofa_image-caption_coco_large_en'), | 'damo/ofa_image-caption_coco_large_en'), | ||||
Tasks.image_generation: | Tasks.image_generation: | ||||
@@ -1,7 +1,9 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from typing import Any, Dict | from typing import Any, Dict | ||||
from ...metainfo import Pipelines | from ...metainfo import Pipelines | ||||
from ...models.nlp import SpaceForDialogIntentModel | |||||
from ...models.nlp import SpaceForDialogIntent | |||||
from ...preprocessors import DialogIntentPredictionPreprocessor | from ...preprocessors import DialogIntentPredictionPreprocessor | ||||
from ...utils.constant import Tasks | from ...utils.constant import Tasks | ||||
from ..base import Pipeline | from ..base import Pipeline | ||||
@@ -15,7 +17,7 @@ __all__ = ['DialogIntentPredictionPipeline'] | |||||
module_name=Pipelines.dialog_intent_prediction) | module_name=Pipelines.dialog_intent_prediction) | ||||
class DialogIntentPredictionPipeline(Pipeline): | class DialogIntentPredictionPipeline(Pipeline): | ||||
def __init__(self, model: SpaceForDialogIntentModel, | |||||
def __init__(self, model: SpaceForDialogIntent, | |||||
preprocessor: DialogIntentPredictionPreprocessor, **kwargs): | preprocessor: DialogIntentPredictionPreprocessor, **kwargs): | ||||
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | ||||
@@ -26,7 +28,7 @@ class DialogIntentPredictionPipeline(Pipeline): | |||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | super().__init__(model=model, preprocessor=preprocessor, **kwargs) | ||||
self.model = model | self.model = model | ||||
# self.tokenizer = preprocessor.tokenizer | |||||
self.categories = preprocessor.categories | |||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | ||||
"""process the prediction results | """process the prediction results | ||||
@@ -41,6 +43,10 @@ class DialogIntentPredictionPipeline(Pipeline): | |||||
pred = inputs['pred'] | pred = inputs['pred'] | ||||
pos = np.where(pred == np.max(pred)) | pos = np.where(pred == np.max(pred)) | ||||
result = {'pred': pred, 'label': pos[0]} | |||||
result = { | |||||
'pred': pred, | |||||
'label_pos': pos[0], | |||||
'label': self.categories[pos[0][0]] | |||||
} | |||||
return result | return result |
@@ -1,7 +1,9 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from typing import Any, Dict, Optional | from typing import Any, Dict, Optional | ||||
from ...metainfo import Pipelines | from ...metainfo import Pipelines | ||||
from ...models.nlp import SpaceForDialogModelingModel | |||||
from ...models.nlp import SpaceForDialogModeling | |||||
from ...preprocessors import DialogModelingPreprocessor | from ...preprocessors import DialogModelingPreprocessor | ||||
from ...utils.constant import Tasks | from ...utils.constant import Tasks | ||||
from ..base import Pipeline, Tensor | from ..base import Pipeline, Tensor | ||||
@@ -14,7 +16,7 @@ __all__ = ['DialogModelingPipeline'] | |||||
Tasks.dialog_modeling, module_name=Pipelines.dialog_modeling) | Tasks.dialog_modeling, module_name=Pipelines.dialog_modeling) | ||||
class DialogModelingPipeline(Pipeline): | class DialogModelingPipeline(Pipeline): | ||||
def __init__(self, model: SpaceForDialogModelingModel, | |||||
def __init__(self, model: SpaceForDialogModeling, | |||||
preprocessor: DialogModelingPreprocessor, **kwargs): | preprocessor: DialogModelingPreprocessor, **kwargs): | ||||
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | ||||
@@ -40,7 +42,6 @@ class DialogModelingPipeline(Pipeline): | |||||
inputs['resp']) | inputs['resp']) | ||||
assert len(sys_rsp) > 2 | assert len(sys_rsp) > 2 | ||||
sys_rsp = sys_rsp[1:len(sys_rsp) - 1] | sys_rsp = sys_rsp[1:len(sys_rsp) - 1] | ||||
# sys_rsp = self.preprocessor.text_field.tokenizer. | |||||
inputs['sys'] = sys_rsp | inputs['sys'] = sys_rsp | ||||
@@ -108,14 +108,19 @@ TASK_OUTPUTS = { | |||||
# } | # } | ||||
Tasks.sentiment_classification: ['scores', 'labels'], | Tasks.sentiment_classification: ['scores', 'labels'], | ||||
# zero-shot classification result for single sample | |||||
# { | |||||
# "labels": ["happy", "sad", "calm", "angry"], | |||||
# "scores": [0.9, 0.1, 0.05, 0.05] | |||||
# } | |||||
Tasks.zero_shot_classification: ['scores', 'labels'], | |||||
# nli result for single sample | # nli result for single sample | ||||
# { | # { | ||||
# "labels": ["happy", "sad", "calm", "angry"], | # "labels": ["happy", "sad", "calm", "angry"], | ||||
# "scores": [0.9, 0.1, 0.05, 0.05] | # "scores": [0.9, 0.1, 0.05, 0.05] | ||||
# } | # } | ||||
Tasks.nli: ['scores', 'labels'], | Tasks.nli: ['scores', 'labels'], | ||||
Tasks.dialog_modeling: [], | |||||
Tasks.dialog_intent_prediction: [], | |||||
# { | # { | ||||
# "dialog_states": { | # "dialog_states": { | ||||
@@ -153,6 +158,31 @@ TASK_OUTPUTS = { | |||||
# } | # } | ||||
Tasks.dialog_state_tracking: ['dialog_states'], | Tasks.dialog_state_tracking: ['dialog_states'], | ||||
# {'pred': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05, | |||||
# 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04, | |||||
# 6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01, | |||||
# 2.66957268e-05, 4.72324500e-05, 9.74208378e-05, 4.18022355e-05, | |||||
# 2.97343540e-05, 5.81317654e-05, 5.44203431e-05, 6.28319322e-05, | |||||
# 7.34537680e-05, 6.61411541e-05, 3.62534920e-05, 8.58885178e-05, | |||||
# 8.24327726e-05, 4.66077945e-05, 5.32869453e-05, 4.16190960e-05, | |||||
# 5.97518992e-05, 3.92273068e-05, 3.44069012e-05, 9.92335918e-05, | |||||
# 9.25978165e-05, 6.26462061e-05, 3.32317031e-05, 1.32061413e-03, | |||||
# 2.01607945e-05, 3.36636294e-05, 3.99156743e-05, 5.84108493e-05, | |||||
# 2.53432900e-05, 4.95731190e-04, 2.64443643e-05, 4.46992999e-05, | |||||
# 2.42672231e-05, 4.75615161e-05, 2.66230145e-05, 4.00083954e-05, | |||||
# 2.90536875e-04, 4.23891543e-05, 8.63691166e-05, 4.98188965e-05, | |||||
# 3.47019341e-05, 4.52718523e-05, 4.20905781e-05, 5.50173208e-05, | |||||
# 4.92360487e-05, 3.56021264e-05, 2.13957210e-05, 6.17428886e-05, | |||||
# 1.43893281e-04, 7.32152112e-05, 2.91354867e-04, 2.46623786e-05, | |||||
# 3.61441926e-05, 3.38475402e-05, 3.44323053e-05, 5.70138109e-05, | |||||
# 4.31488479e-05, 4.94503947e-05, 4.30105974e-05, 1.00963116e-04, | |||||
# 2.82062047e-05, 1.15582036e-04, 4.48261271e-05, 3.99339879e-05, | |||||
# 7.27692823e-05], dtype=float32), 'label_pos': array([11]), 'label': 'lost_or_stolen_card'} | |||||
Tasks.dialog_intent_prediction: ['pred', 'label_pos', 'label'], | |||||
# sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!'] | |||||
Tasks.dialog_modeling: ['sys'], | |||||
# ============ audio tasks =================== | # ============ audio tasks =================== | ||||
# audio processed for single file in PCM format | # audio processed for single file in PCM format | ||||
@@ -15,7 +15,7 @@ __all__ = [ | |||||
'Tokenize', 'SequenceClassificationPreprocessor', | 'Tokenize', 'SequenceClassificationPreprocessor', | ||||
'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor', | 'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor', | ||||
'NLIPreprocessor', 'SentimentClassificationPreprocessor', | 'NLIPreprocessor', 'SentimentClassificationPreprocessor', | ||||
'FillMaskPreprocessor' | |||||
'FillMaskPreprocessor', 'ZeroShotClassificationPreprocessor' | |||||
] | ] | ||||
@@ -421,3 +421,47 @@ class TokenClassifcationPreprocessor(Preprocessor): | |||||
'attention_mask': attention_mask, | 'attention_mask': attention_mask, | ||||
'token_type_ids': token_type_ids | 'token_type_ids': token_type_ids | ||||
} | } | ||||
@PREPROCESSORS.register_module( | |||||
Fields.nlp, module_name=Preprocessors.zero_shot_cls_tokenizer) | |||||
class ZeroShotClassificationPreprocessor(Preprocessor): | |||||
def __init__(self, model_dir: str, *args, **kwargs): | |||||
"""preprocess the data via the vocab.txt from the `model_dir` path | |||||
Args: | |||||
model_dir (str): model path | |||||
""" | |||||
super().__init__(*args, **kwargs) | |||||
from sofa import SbertTokenizer | |||||
self.model_dir: str = model_dir | |||||
self.sequence_length = kwargs.pop('sequence_length', 512) | |||||
self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir) | |||||
@type_assert(object, str) | |||||
def __call__(self, data: str, hypothesis_template: str, | |||||
candidate_labels: list) -> Dict[str, Any]: | |||||
"""process the raw input data | |||||
Args: | |||||
data (str): a sentence | |||||
Example: | |||||
'you are so handsome.' | |||||
Returns: | |||||
Dict[str, Any]: the preprocessed data | |||||
""" | |||||
pairs = [[data, hypothesis_template.format(label)] | |||||
for label in candidate_labels] | |||||
features = self.tokenizer( | |||||
pairs, | |||||
padding=True, | |||||
truncation=True, | |||||
max_length=self.sequence_length, | |||||
return_tensors='pt', | |||||
truncation_strategy='only_first') | |||||
return features |
@@ -3,6 +3,9 @@ | |||||
import os | import os | ||||
from typing import Any, Dict | from typing import Any, Dict | ||||
import json | |||||
from ...metainfo import Preprocessors | |||||
from ...utils.config import Config | from ...utils.config import Config | ||||
from ...utils.constant import Fields, ModelFile | from ...utils.constant import Fields, ModelFile | ||||
from ...utils.type_assert import type_assert | from ...utils.type_assert import type_assert | ||||
@@ -13,7 +16,8 @@ from .fields.intent_field import IntentBPETextField | |||||
__all__ = ['DialogIntentPredictionPreprocessor'] | __all__ = ['DialogIntentPredictionPreprocessor'] | ||||
@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space-intent') | |||||
@PREPROCESSORS.register_module( | |||||
Fields.nlp, module_name=Preprocessors.dialog_intent_preprocessor) | |||||
class DialogIntentPredictionPreprocessor(Preprocessor): | class DialogIntentPredictionPreprocessor(Preprocessor): | ||||
def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
@@ -30,6 +34,11 @@ class DialogIntentPredictionPreprocessor(Preprocessor): | |||||
self.text_field = IntentBPETextField( | self.text_field = IntentBPETextField( | ||||
self.model_dir, config=self.config) | self.model_dir, config=self.config) | ||||
self.categories = None | |||||
with open(os.path.join(self.model_dir, 'categories.json'), 'r') as f: | |||||
self.categories = json.load(f) | |||||
assert len(self.categories) == 77 | |||||
@type_assert(object, str) | @type_assert(object, str) | ||||
def __call__(self, data: str) -> Dict[str, Any]: | def __call__(self, data: str) -> Dict[str, Any]: | ||||
"""process the raw input data | """process the raw input data | ||||
@@ -3,6 +3,7 @@ | |||||
import os | import os | ||||
from typing import Any, Dict | from typing import Any, Dict | ||||
from ...metainfo import Preprocessors | |||||
from ...utils.config import Config | from ...utils.config import Config | ||||
from ...utils.constant import Fields, ModelFile | from ...utils.constant import Fields, ModelFile | ||||
from ...utils.type_assert import type_assert | from ...utils.type_assert import type_assert | ||||
@@ -13,7 +14,8 @@ from .fields.gen_field import MultiWOZBPETextField | |||||
__all__ = ['DialogModelingPreprocessor'] | __all__ = ['DialogModelingPreprocessor'] | ||||
@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space-modeling') | |||||
@PREPROCESSORS.register_module( | |||||
Fields.nlp, module_name=Preprocessors.dialog_modeling_preprocessor) | |||||
class DialogModelingPreprocessor(Preprocessor): | class DialogModelingPreprocessor(Preprocessor): | ||||
def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
@@ -1,6 +1,5 @@ | |||||
""" | |||||
Field class | |||||
""" | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import os | import os | ||||
import random | import random | ||||
from collections import OrderedDict | from collections import OrderedDict | ||||
@@ -8,7 +7,6 @@ from itertools import chain | |||||
import numpy as np | import numpy as np | ||||
from ....utils.constant import ModelFile | |||||
from ....utils.nlp.space import ontology, utils | from ....utils.nlp.space import ontology, utils | ||||
from ....utils.nlp.space.db_ops import MultiWozDB | from ....utils.nlp.space.db_ops import MultiWozDB | ||||
from ....utils.nlp.space.utils import list2np | from ....utils.nlp.space.utils import list2np | ||||
@@ -1,6 +1,5 @@ | |||||
""" | |||||
Intent Field class | |||||
""" | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import glob | import glob | ||||
import multiprocessing | import multiprocessing | ||||
import os | import os | ||||
@@ -308,14 +308,6 @@ if __name__ == '__main__': | |||||
'attraction': 5, | 'attraction': 5, | ||||
'train': 1, | 'train': 1, | ||||
} | } | ||||
# for ent in res: | |||||
# if reidx.get(domain): | |||||
# report.append(ent[reidx[domain]]) | |||||
# for ent in res: | |||||
# if 'name' in ent: | |||||
# report.append(ent['name']) | |||||
# if 'trainid' in ent: | |||||
# report.append(ent['trainid']) | |||||
print(constraints) | print(constraints) | ||||
print(res) | print(res) | ||||
print('count:', len(res), '\nnames:', report) | print('count:', len(res), '\nnames:', report) |
@@ -123,19 +123,6 @@ dialog_act_all_slots = all_slots + ['choice', 'open'] | |||||
# no need of this, just covert slot to [slot] e.g. pricerange -> [pricerange] | # no need of this, just covert slot to [slot] e.g. pricerange -> [pricerange] | ||||
slot_name_to_slot_token = {} | slot_name_to_slot_token = {} | ||||
# special slot tokens in responses | |||||
# not use at the momoent | |||||
slot_name_to_value_token = { | |||||
# 'entrance fee': '[value_price]', | |||||
# 'pricerange': '[value_price]', | |||||
# 'arriveby': '[value_time]', | |||||
# 'leaveat': '[value_time]', | |||||
# 'departure': '[value_place]', | |||||
# 'destination': '[value_place]', | |||||
# 'stay': 'count', | |||||
# 'people': 'count' | |||||
} | |||||
# eos tokens definition | # eos tokens definition | ||||
eos_tokens = { | eos_tokens = { | ||||
'user': '<eos_u>', | 'user': '<eos_u>', | ||||
@@ -53,16 +53,9 @@ def clean_replace(s, r, t, forward=True, backward=False): | |||||
return s, -1 | return s, -1 | ||||
return s[:idx] + t + s[idx_r:], idx_r | return s[:idx] + t + s[idx_r:], idx_r | ||||
# source, replace, target = s, r, t | |||||
# count = 0 | |||||
sidx = 0 | sidx = 0 | ||||
while sidx != -1: | while sidx != -1: | ||||
s, sidx = clean_replace_single(s, r, t, forward, backward, sidx) | s, sidx = clean_replace_single(s, r, t, forward, backward, sidx) | ||||
# count += 1 | |||||
# print(s, sidx) | |||||
# if count == 20: | |||||
# print(source, '\n', replace, '\n', target) | |||||
# quit() | |||||
return s | return s | ||||
@@ -193,14 +186,3 @@ class MultiWOZVocab(object): | |||||
return self._idx2word[idx] | return self._idx2word[idx] | ||||
else: | else: | ||||
return self._idx2word[idx] + '(o)' | return self._idx2word[idx] + '(o)' | ||||
# def sentence_decode(self, index_list, eos=None, indicate_oov=False): | |||||
# l = [self.decode(_, indicate_oov) for _ in index_list] | |||||
# if not eos or eos not in l: | |||||
# return ' '.join(l) | |||||
# else: | |||||
# idx = l.index(eos) | |||||
# return ' '.join(l[:idx]) | |||||
# | |||||
# def nl_decode(self, l, eos=None): | |||||
# return [self.sentence_decode(_, eos) + '\n' for _ in l] |
@@ -3,10 +3,11 @@ import unittest | |||||
from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
from modelscope.models import Model | from modelscope.models import Model | ||||
from modelscope.models.nlp import SpaceForDialogIntentModel | |||||
from modelscope.models.nlp import SpaceForDialogIntent | |||||
from modelscope.pipelines import DialogIntentPredictionPipeline, pipeline | from modelscope.pipelines import DialogIntentPredictionPipeline, pipeline | ||||
from modelscope.preprocessors import DialogIntentPredictionPreprocessor | from modelscope.preprocessors import DialogIntentPredictionPreprocessor | ||||
from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
from modelscope.utils.test_utils import test_level | |||||
class DialogIntentPredictionTest(unittest.TestCase): | class DialogIntentPredictionTest(unittest.TestCase): | ||||
@@ -16,11 +17,11 @@ class DialogIntentPredictionTest(unittest.TestCase): | |||||
'I still have not received my new card, I ordered over a week ago.' | 'I still have not received my new card, I ordered over a week ago.' | ||||
] | ] | ||||
@unittest.skip('test with snapshot_download') | |||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
def test_run(self): | def test_run(self): | ||||
cache_path = snapshot_download(self.model_id) | cache_path = snapshot_download(self.model_id) | ||||
preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) | preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) | ||||
model = SpaceForDialogIntentModel( | |||||
model = SpaceForDialogIntent( | |||||
model_dir=cache_path, | model_dir=cache_path, | ||||
text_field=preprocessor.text_field, | text_field=preprocessor.text_field, | ||||
config=preprocessor.config) | config=preprocessor.config) | ||||
@@ -37,6 +38,7 @@ class DialogIntentPredictionTest(unittest.TestCase): | |||||
for my_pipeline, item in list(zip(pipelines, self.test_case)): | for my_pipeline, item in list(zip(pipelines, self.test_case)): | ||||
print(my_pipeline(item)) | print(my_pipeline(item)) | ||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
def test_run_with_model_from_modelhub(self): | def test_run_with_model_from_modelhub(self): | ||||
model = Model.from_pretrained(self.model_id) | model = Model.from_pretrained(self.model_id) | ||||
preprocessor = DialogIntentPredictionPreprocessor( | preprocessor = DialogIntentPredictionPreprocessor( | ||||
@@ -1,15 +1,13 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
import os | |||||
import os.path as osp | |||||
import tempfile | |||||
import unittest | import unittest | ||||
from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
from modelscope.models import Model | from modelscope.models import Model | ||||
from modelscope.models.nlp import SpaceForDialogModelingModel | |||||
from modelscope.models.nlp import SpaceForDialogModeling | |||||
from modelscope.pipelines import DialogModelingPipeline, pipeline | from modelscope.pipelines import DialogModelingPipeline, pipeline | ||||
from modelscope.preprocessors import DialogModelingPreprocessor | from modelscope.preprocessors import DialogModelingPreprocessor | ||||
from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
from modelscope.utils.test_utils import test_level | |||||
class DialogModelingTest(unittest.TestCase): | class DialogModelingTest(unittest.TestCase): | ||||
@@ -91,13 +89,13 @@ class DialogModelingTest(unittest.TestCase): | |||||
} | } | ||||
} | } | ||||
@unittest.skip('test with snapshot_download') | |||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
def test_run(self): | def test_run(self): | ||||
cache_path = snapshot_download(self.model_id) | cache_path = snapshot_download(self.model_id) | ||||
preprocessor = DialogModelingPreprocessor(model_dir=cache_path) | preprocessor = DialogModelingPreprocessor(model_dir=cache_path) | ||||
model = SpaceForDialogModelingModel( | |||||
model = SpaceForDialogModeling( | |||||
model_dir=cache_path, | model_dir=cache_path, | ||||
text_field=preprocessor.text_field, | text_field=preprocessor.text_field, | ||||
config=preprocessor.config) | config=preprocessor.config) | ||||
@@ -120,6 +118,7 @@ class DialogModelingTest(unittest.TestCase): | |||||
}) | }) | ||||
print('sys : {}'.format(result['sys'])) | print('sys : {}'.format(result['sys'])) | ||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
def test_run_with_model_from_modelhub(self): | def test_run_with_model_from_modelhub(self): | ||||
model = Model.from_pretrained(self.model_id) | model = Model.from_pretrained(self.model_id) | ||||
preprocessor = DialogModelingPreprocessor(model_dir=model.model_dir) | preprocessor = DialogModelingPreprocessor(model_dir=model.model_dir) | ||||
@@ -29,7 +29,7 @@ class NLITest(unittest.TestCase): | |||||
f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n' | f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n' | ||||
f'pipeline1: {pipeline2(input=(self.sentence1, self.sentence2))}') | f'pipeline1: {pipeline2(input=(self.sentence1, self.sentence2))}') | ||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
def test_run_with_model_from_modelhub(self): | def test_run_with_model_from_modelhub(self): | ||||
model = Model.from_pretrained(self.model_id) | model = Model.from_pretrained(self.model_id) | ||||
tokenizer = NLIPreprocessor(model.model_dir) | tokenizer = NLIPreprocessor(model.model_dir) | ||||
@@ -42,7 +42,7 @@ class SentimentClassificationTest(unittest.TestCase): | |||||
preprocessor=tokenizer) | preprocessor=tokenizer) | ||||
print(pipeline_ins(input=self.sentence1)) | print(pipeline_ins(input=self.sentence1)) | ||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
def test_run_with_model_name(self): | def test_run_with_model_name(self): | ||||
pipeline_ins = pipeline( | pipeline_ins = pipeline( | ||||
task=Tasks.sentiment_classification, model=self.model_id) | task=Tasks.sentiment_classification, model=self.model_id) | ||||