wanggui.hwg  yingda.chen  3 years ago
commit e8608df930
15 changed files with 744 additions and 2 deletions
  1. modelscope/metainfo.py  +3  -0
  2. modelscope/models/nlp/__init__.py  +2  -0
  3. modelscope/models/nlp/unite/__init__.py  +24  -0
  4. modelscope/models/nlp/unite/configuration_unite.py  +21  -0
  5. modelscope/models/nlp/unite/modeling_unite.py  +400  -0
  6. modelscope/outputs/outputs.py  +5  -0
  7. modelscope/pipeline_inputs.py  +5  -0
  8. modelscope/pipelines/builder.py  +3  -0
  9. modelscope/pipelines/nlp/__init__.py  +2  -0
  10. modelscope/pipelines/nlp/translation_evaluation_pipeline.py  +111  -0
  11. modelscope/preprocessors/__init__.py  +4  -2
  12. modelscope/preprocessors/nlp/__init__.py  +3  -0
  13. modelscope/preprocessors/nlp/translation_evaluation_preprocessor.py  +87  -0
  14. modelscope/utils/constant.py  +1  -0
  15. tests/pipelines/test_translation_evaluation.py  +73  -0

modelscope/metainfo.py  +3  -0

@@ -90,6 +90,7 @@ class Models(object):
mglm = 'mglm'
codegeex = 'codegeex'
bloom = 'bloom'
unite = 'unite'

# audio models
sambert_hifigan = 'sambert-hifigan'
@@ -275,6 +276,7 @@ class Pipelines(object):
translation_en_to_ro = 'translation_en_to_ro' # keep it underscore
translation_en_to_fr = 'translation_en_to_fr' # keep it underscore
token_classification = 'token-classification'
translation_evaluation = 'translation-evaluation'

# audio tasks
sambert_hifigan_tts = 'sambert-hifigan-tts'
@@ -404,6 +406,7 @@ class Preprocessors(object):
feature_extraction = 'feature-extraction'
mglm_summarization = 'mglm-summarization'
sentence_piece = 'sentence-piece'
translation_evaluation = 'translation-evaluation-preprocessor'

# audio preprocessor
linear_aec_fbank = 'linear-aec-fbank'


modelscope/models/nlp/__init__.py  +2  -0

@@ -51,6 +51,7 @@ if TYPE_CHECKING:
VecoForSequenceClassification,
VecoForTokenClassification, VecoModel)
from .bloom import BloomModel
from .unite import UniTEForTranslationEvaluation
else:
_import_structure = {
'backbones': ['SbertModel'],
@@ -108,6 +109,7 @@ else:
['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'],
'gpt_neo': ['GPTNeoModel'],
'bloom': ['BloomModel'],
'unite': ['UniTEForTranslationEvaluation']
}

import sys


modelscope/models/nlp/unite/__init__.py  +24  -0

@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .configuration_unite import UniTEConfig
from .modeling_unite import UniTEForTranslationEvaluation
else:
_import_structure = {
'configuration_unite': ['UniTEConfig'],
'modeling_unite': ['UniTEForTranslationEvaluation'],
}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

modelscope/models/nlp/unite/configuration_unite.py  +21  -0

@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
"""UniTE model configuration"""

from enum import Enum

from modelscope.utils import logger as logging
from modelscope.utils.config import Config

logger = logging.get_logger(__name__)


class EvaluationMode(Enum):
SRC = 'src'
REF = 'ref'
SRC_REF = 'src-ref'


class UniTEConfig(Config):

def __init__(self, **kwargs):
super().__init__(**kwargs)

modelscope/models/nlp/unite/modeling_unite.py  +400  -0

@@ -0,0 +1,400 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
"""PyTorch UniTE model."""

import math
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from packaging import version
from torch.nn import (Dropout, Linear, Module, Parameter, ParameterList,
Sequential)
from torch.nn.functional import softmax
from torch.nn.utils.rnn import pad_sequence
from transformers import XLMRobertaConfig, XLMRobertaModel
from transformers.activations import ACT2FN

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger(__name__)

__all__ = ['UniTEForTranslationEvaluation']


def _layer_norm_all(tensor, mask_float):
broadcast_mask = mask_float.unsqueeze(dim=-1)
num_elements_not_masked = broadcast_mask.sum() * tensor.size(-1)
tensor_masked = tensor * broadcast_mask

mean = tensor_masked.sum([-1, -2, -3],
keepdim=True) / num_elements_not_masked
variance = (((tensor_masked - mean) * broadcast_mask)**2).sum(
[-1, -2, -3], keepdim=True) / num_elements_not_masked

return (tensor - mean) / torch.sqrt(variance + 1e-12)


class LayerwiseAttention(Module):

def __init__(
self,
num_layers: int,
model_dim: int,
dropout: float = None,
) -> None:
super(LayerwiseAttention, self).__init__()
self.num_layers = num_layers
self.model_dim = model_dim
self.dropout = dropout

self.scalar_parameters = Parameter(
torch.zeros((num_layers, ), requires_grad=True))
self.gamma = Parameter(torch.FloatTensor([1.0]), requires_grad=True)

if self.dropout:
dropout_mask = torch.zeros(len(self.scalar_parameters))
dropout_fill = torch.empty(len(
self.scalar_parameters)).fill_(-1e20)
self.register_buffer('dropout_mask', dropout_mask)
self.register_buffer('dropout_fill', dropout_fill)

def forward(
self,
tensors: List[torch.Tensor], # pylint: disable=arguments-differ
mask: torch.Tensor = None,
) -> torch.Tensor:
tensors = torch.cat(list(x.unsqueeze(dim=0) for x in tensors), dim=0)
normed_weights = softmax(
self.scalar_parameters, dim=0).view(-1, 1, 1, 1)

mask_float = mask.float()
weighted_sum = (normed_weights
* _layer_norm_all(tensors, mask_float)).sum(dim=0)
weighted_sum = weighted_sum[:, 0, :]

return self.gamma * weighted_sum


class FeedForward(Module):

def __init__(
self,
in_dim: int,
out_dim: int = 1,
hidden_sizes: List[int] = [3072, 768],
activations: str = 'Sigmoid',
final_activation: Optional[str] = None,
dropout: float = 0.1,
) -> None:
"""
Feed Forward Neural Network.

Args:
in_dim (:obj:`int`):
Number of input features.
out_dim (:obj:`int`, defaults to 1):
Number of output features. Default is 1 -- a single scalar.
hidden_sizes (:obj:`List[int]`, defaults to `[3072, 768]`):
List with hidden layer sizes.
activations (:obj:`str`, defaults to `Sigmoid`):
Name of the activation function to be used in the hidden layers.
final_activation (:obj:`str`, Optional, defaults to `None`):
Name of the final activation function if any.
dropout (:obj:`float`, defaults to 0.1):
Dropout ratio to be used in the hidden layers.
"""
super().__init__()
modules = []
modules.append(Linear(in_dim, hidden_sizes[0]))
modules.append(self.build_activation(activations))
modules.append(Dropout(dropout))

for i in range(1, len(hidden_sizes)):
modules.append(Linear(hidden_sizes[i - 1], hidden_sizes[i]))
modules.append(self.build_activation(activations))
modules.append(Dropout(dropout))

modules.append(Linear(hidden_sizes[-1], int(out_dim)))
if final_activation is not None:
modules.append(self.build_activation(final_activation))

self.ff = Sequential(*modules)

def build_activation(self, activation: str) -> Module:
return ACT2FN[activation]

def forward(self, in_features: torch.Tensor) -> torch.Tensor:
return self.ff(in_features)


@MODELS.register_module(Tasks.translation_evaluation, module_name=Models.unite)
class UniTEForTranslationEvaluation(TorchModel):

def __init__(self,
attention_probs_dropout_prob: float = 0.1,
bos_token_id: int = 0,
eos_token_id: int = 2,
pad_token_id: int = 1,
hidden_act: str = 'gelu',
hidden_dropout_prob: float = 0.1,
hidden_size: int = 1024,
initializer_range: float = 0.02,
intermediate_size: int = 4096,
layer_norm_eps: float = 1e-05,
max_position_embeddings: int = 512,
num_attention_heads: int = 16,
num_hidden_layers: int = 24,
type_vocab_size: int = 1,
use_cache: bool = True,
vocab_size: int = 250002,
mlp_hidden_sizes: List[int] = [3072, 1024],
mlp_act: str = 'tanh',
mlp_final_act: Optional[str] = None,
mlp_dropout: float = 0.1,
**kwargs):
r"""The UniTE Model which outputs the scalar to describe the corresponding
translation quality of hypothesis. The model architecture includes two
modules: a pre-trained language model (PLM) to derive representations,
and a multi-layer perceptron (MLP) to give predicted score.

Args:
attention_probs_dropout_prob (:obj:`float`, defaults to 0.1):
The dropout ratio for attention weights inside PLM.
bos_token_id (:obj:`int`, defaults to 0):
The numeric id representing beginning-of-sentence symbol.
eos_token_id (:obj:`int`, defaults to 2):
The numeric id representing ending-of-sentence symbol.
pad_token_id (:obj:`int`, defaults to 1):
The numeric id representing padding symbol.
hidden_act (:obj:`str`, defaults to :obj:`"gelu"`):
The activation function used inside the PLM.
hidden_dropout_prob (:obj:`float`, defaults to 0.1):
The dropout ratio for hidden states inside the PLM.
hidden_size (:obj:`int`, defaults to 1024):
The hidden dimensionality of the PLM.
initializer_range (:obj:`float`, defaults to 0.02):
The standard deviation used for initializing the PLM weights.
intermediate_size (:obj:`int`, defaults to 4096):
The dimensionality of the feed-forward block inside the PLM.
layer_norm_eps (:obj:`float`, defaults to 1e-5):
The epsilon used by layer normalization to avoid division by zero.
max_position_embeddings (:obj:`int`, defaults to 512):
The maximum input sequence length supported by the position embeddings.
num_attention_heads (:obj:`int`, defaults to 16):
The number of attention heads inside multi-head attention layer.
num_hidden_layers (:obj:`int`, defaults to 24):
The number of layers inside PLM.
type_vocab_size (:obj:`int`, defaults to 1):
The number of type embeddings.
use_cache (:obj:`bool`, defaults to :obj:`True`):
Whether the PLM should return the key/value attention caches.
vocab_size (:obj:`int`, defaults to 250002):
The size of the vocabulary.
mlp_hidden_sizes (:obj:`List[int]`, defaults to `[3072, 1024]`):
The sizes of the hidden layers inside the MLP.
mlp_act (:obj:`str`, defaults to :obj:`"tanh"`):
Activation inside MLP.
mlp_final_act (:obj:`str`, `optional`, defaults to :obj:`None`):
Activation at the end of MLP.
mlp_dropout (:obj:`float`, defaults to 0.1):
The dropout ratio for MLP.
"""
super().__init__(**kwargs)

self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.hidden_size = hidden_size
self.initializer_range = initializer_range
self.intermediate_size = intermediate_size
self.layer_norm_eps = layer_norm_eps
self.max_position_embeddings = max_position_embeddings
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_hidden_layers
self.type_vocab_size = type_vocab_size
self.use_cache = use_cache
self.vocab_size = vocab_size
self.mlp_hidden_sizes = mlp_hidden_sizes
self.mlp_act = mlp_act
self.mlp_final_act = mlp_final_act
self.mlp_dropout = mlp_dropout

self.encoder_config = XLMRobertaConfig(
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
layer_norm_eps=self.layer_norm_eps,
use_cache=self.use_cache)

self.encoder = XLMRobertaModel(
self.encoder_config, add_pooling_layer=False)

self.layerwise_attention = LayerwiseAttention(
num_layers=self.num_hidden_layers + 1,
model_dim=self.hidden_size,
dropout=self.mlp_dropout)

self.estimator = FeedForward(
in_dim=self.hidden_size,
out_dim=1,
hidden_sizes=self.mlp_hidden_sizes,
activations=self.mlp_act,
final_activation=self.mlp_final_act,
dropout=self.mlp_dropout)

return

def forward(self, input_sentences: List[torch.Tensor]):
input_ids = self.combine_input_sentences(input_sentences)
attention_mask = input_ids.ne(self.pad_token_id).long()
outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict=True)
mix_states = self.layerwise_attention(outputs['hidden_states'],
attention_mask)
pred = self.estimator(mix_states)
return pred.squeeze(dim=-1)

def load_checkpoint(self, path: str):
state_dict = torch.load(path)
self.load_state_dict(state_dict)
logger.info('Loaded checkpoint parameters from %s' % path)
return

def combine_input_sentences(self, input_sent_groups: List[torch.Tensor]):
for input_sent_group in input_sent_groups[1:]:
input_sent_group[:, 0] = self.eos_token_id

if len(input_sent_groups) == 3:
cutted_sents = self.cut_long_sequences3(input_sent_groups)
else:
cutted_sents = self.cut_long_sequences2(input_sent_groups)
return cutted_sents

@staticmethod
def cut_long_sequences2(all_input_concat: List[List[torch.Tensor]],
maximum_length: int = 512,
pad_idx: int = 1):
all_input_concat = list(zip(*all_input_concat))
collected_tuples = list()
for tensor_tuple in all_input_concat:
all_lens = tuple(len(x) for x in tensor_tuple)

if sum(all_lens) > maximum_length:
lengths = dict(enumerate(all_lens))
lengths_sorted_idxes = list(x[0] for x in sorted(
lengths.items(), key=lambda d: d[1], reverse=True))

offset = math.ceil((sum(lengths.values()) - maximum_length) / 2)

if min(all_lens) > (maximum_length
// 2) and min(all_lens) > offset:
lengths = dict((k, v - offset) for k, v in lengths.items())
else:
lengths[lengths_sorted_idxes[
0]] = maximum_length - lengths[lengths_sorted_idxes[1]]

new_lens = list(lengths[k]
for k in range(0, len(tensor_tuple)))
new_tensor_tuple = tuple(
x[:y] for x, y in zip(tensor_tuple, new_lens))
for x, y in zip(new_tensor_tuple, tensor_tuple):
x[-1] = y[-1]
collected_tuples.append(new_tensor_tuple)
else:
collected_tuples.append(tensor_tuple)

concat_tensor = list(torch.cat(x, dim=0) for x in collected_tuples)
all_input_concat_padded = pad_sequence(
concat_tensor, batch_first=True, padding_value=pad_idx)

return all_input_concat_padded

@staticmethod
def cut_long_sequences3(all_input_concat: List[List[torch.Tensor]],
maximum_length: int = 512,
pad_idx: int = 1):
all_input_concat = list(zip(*all_input_concat))
collected_tuples = list()
for tensor_tuple in all_input_concat:
all_lens = tuple(len(x) for x in tensor_tuple)

if sum(all_lens) > maximum_length:
lengths = dict(enumerate(all_lens))
lengths_sorted_idxes = list(x[0] for x in sorted(
lengths.items(), key=lambda d: d[1], reverse=True))

offset = math.ceil((sum(lengths.values()) - maximum_length) / 3)

if min(all_lens) > (maximum_length
// 3) and min(all_lens) > offset:
lengths = dict((k, v - offset) for k, v in lengths.items())
else:
while sum(lengths.values()) > maximum_length:
if lengths[lengths_sorted_idxes[0]] > lengths[
lengths_sorted_idxes[1]]:
offset = maximum_length - lengths[
lengths_sorted_idxes[1]] - lengths[
lengths_sorted_idxes[2]]
if offset > lengths[lengths_sorted_idxes[1]]:
lengths[lengths_sorted_idxes[0]] = offset
else:
lengths[lengths_sorted_idxes[0]] = lengths[
lengths_sorted_idxes[1]]
elif lengths[lengths_sorted_idxes[0]] == lengths[
lengths_sorted_idxes[1]] > lengths[
lengths_sorted_idxes[2]]:
offset = (maximum_length
- lengths[lengths_sorted_idxes[2]]) // 2
if offset > lengths[lengths_sorted_idxes[2]]:
lengths[lengths_sorted_idxes[0]] = lengths[
lengths_sorted_idxes[1]] = offset
else:
lengths[lengths_sorted_idxes[0]] = lengths[
lengths_sorted_idxes[1]] = lengths[
lengths_sorted_idxes[2]]
else:
lengths[lengths_sorted_idxes[0]] = lengths[
lengths_sorted_idxes[1]] = lengths[
lengths_sorted_idxes[
2]] = maximum_length // 3

new_lens = list(lengths[k] for k in range(0, len(lengths)))
new_tensor_tuple = tuple(
x[:y] for x, y in zip(tensor_tuple, new_lens))

for x, y in zip(new_tensor_tuple, tensor_tuple):
x[-1] = y[-1]
collected_tuples.append(new_tensor_tuple)
else:
collected_tuples.append(tensor_tuple)

concat_tensor = list(torch.cat(x, dim=0) for x in collected_tuples)
all_input_concat_padded = pad_sequence(
concat_tensor, batch_first=True, padding_value=pad_idx)

return all_input_concat_padded
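
A minimal, hedged sketch of how `UniTEForTranslationEvaluation.forward()` consumes its inputs: one `[batch, seq_len]` tensor of token ids per segment (hypothesis, then source and/or reference). The shrunken layer sizes and the random token ids below are purely illustrative assumptions so the snippet runs quickly; real scoring keeps the XLM-R large defaults and calls `load_checkpoint()` on the released weights first.

```
import torch

from modelscope.models.nlp.unite.modeling_unite import UniTEForTranslationEvaluation

# Tiny, randomly initialised configuration -- illustration only, not the released model.
model = UniTEForTranslationEvaluation(
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=128,
    mlp_hidden_sizes=[32, 16])
model.eval()

# Dummy token ids standing in for tokenized hypothesis / source / reference.
hyp_ids = torch.randint(3, 1000, (2, 6))
src_ids = torch.randint(3, 1000, (2, 6))
ref_ids = torch.randint(3, 1000, (2, 6))

with torch.no_grad():
    # Three segments -> SRC_REF mode; the model concatenates them per example.
    scores = model([hyp_ids, src_ids, ref_ids])

print(scores.shape)  # torch.Size([2]): one quality score per hypothesis
```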

modelscope/outputs/outputs.py  +5  -0

@@ -801,6 +801,11 @@ TASK_OUTPUTS = {
# ]
# }
Tasks.product_segmentation: [OutputKeys.MASKS],

# {
# 'scores': [0.1, 0.2, 0.3, ...]
# }
Tasks.translation_evaluation: [OutputKeys.SCORES]
}




modelscope/pipeline_inputs.py  +5  -0

@@ -183,6 +183,11 @@ TASK_INPUTS = {
'query_set': InputType.LIST,
'support_set': InputType.LIST,
},
Tasks.translation_evaluation: {
'hyp': InputType.LIST,
'src': InputType.LIST,
'ref': InputType.LIST,
},

# ============ audio tasks ===================
Tasks.auto_speech_recognition:


modelscope/pipelines/builder.py  +3  -0

@@ -217,6 +217,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_swin-t_referring_video-object-segmentation'),
Tasks.video_summarization: (Pipelines.video_summarization,
'damo/cv_googlenet_pgl-video-summarization'),
Tasks.translation_evaluation:
(Pipelines.translation_evaluation,
'damo/nlp_unite_mup_translation_evaluation_multilingual_large'),
}




modelscope/pipelines/nlp/__init__.py  +2  -0

@@ -32,6 +32,7 @@ if TYPE_CHECKING:
from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline
from .codegeex_code_translation_pipeline import CodeGeeXCodeTranslationPipeline
from .codegeex_code_generation_pipeline import CodeGeeXCodeGenerationPipeline
from .translation_evaluation_pipeline import TranslationEvaluationPipeline

else:
_import_structure = {
@@ -77,6 +78,7 @@ else:
['CodeGeeXCodeTranslationPipeline'],
'codegeex_code_generation_pipeline':
['CodeGeeXCodeGenerationPipeline'],
'translation_evaluation_pipeline': ['TranslationEvaluationPipeline'],
}

import sys


modelscope/pipelines/nlp/translation_evaluation_pipeline.py  +111  -0

@@ -0,0 +1,111 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os.path as osp
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.base import Model
from modelscope.models.nlp.unite.configuration_unite import EvaluationMode
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import InputModel, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import (Preprocessor,
TranslationEvaluationPreprocessor)
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger(__name__)

__all__ = ['TranslationEvaluationPipeline']


@PIPELINES.register_module(
Tasks.translation_evaluation, module_name=Pipelines.translation_evaluation)
class TranslationEvaluationPipeline(Pipeline):

def __init__(self,
model: InputModel,
preprocessor: Optional[Preprocessor] = None,
eval_mode: EvaluationMode = EvaluationMode.SRC_REF,
**kwargs):
r"""Build a translation pipeline with a model dir or a model id in the model hub.

Args:
model: A Model instance.
eval_mode: Evaluation mode, choosing one from `"EvaluationMode.SRC_REF"`,
`"EvaluationMode.SRC"`, `"EvaluationMode.REF"`. Aside from hypothesis, the
source/reference/source+reference can be presented during evaluation.
"""
super().__init__(model=model, preprocessor=preprocessor)

self.eval_mode = eval_mode
self.checking_eval_mode()

self.preprocessor = TranslationEvaluationPreprocessor(
self.model.model_dir,
self.eval_mode) if preprocessor is None else preprocessor

self.model.load_checkpoint(
osp.join(self.model.model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
self.model.eval()

return

def checking_eval_mode(self):
if self.eval_mode == EvaluationMode.SRC:
logger.info('Evaluation mode: source-only')
elif self.eval_mode == EvaluationMode.REF:
logger.info('Evaluation mode: reference-only')
elif self.eval_mode == EvaluationMode.SRC_REF:
logger.info('Evaluation mode: source-reference-combined')
else:
raise ValueError(
'Evaluation mode should be one of '
'\'EvaluationMode.SRC\', \'EvaluationMode.REF\' and '
'\'EvaluationMode.SRC_REF\'.')

def change_eval_mode(self,
eval_mode: EvaluationMode = EvaluationMode.SRC_REF):
logger.info('Changing the evaluation mode.')
self.eval_mode = eval_mode
self.checking_eval_mode()
self.preprocessor.eval_mode = eval_mode
return

def __call__(self, input_dict: Dict[str, Union[str, List[str]]], **kwargs):
r"""Implementation of __call__ function.

Args:
input_dict: The formatted dict containing the inputted sentences.
An example of the formatted dict:
```
input_dict = {
'hyp': [
'This is a sentence.',
'This is another sentence.',
],
'src': [
'这是个句子。',
'这是另一个句子。',
],
'ref': [
'It is a sentence.',
'It is another sentence.',
]
}
```
"""
return super().__call__(input=input_dict, **kwargs)

def forward(self,
input_ids: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
return self.model(input_ids)

def postprocess(self, output: torch.Tensor) -> Dict[str, Any]:
result = {OutputKeys.SCORES: output.cpu().tolist()}
return result
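
A hedged usage sketch for the pipeline, mirroring the unit test further below: it builds the registered pipeline from the default model id, reads the per-hypothesis scores out of `OutputKeys.SCORES`, and switches the evaluation mode. It assumes the checkpoint `damo/nlp_unite_mup_translation_evaluation_multilingual_large` can be downloaded from the model hub, so treat it as illustrative rather than self-contained.

```
from modelscope.models.nlp.unite.configuration_unite import EvaluationMode
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(
    Tasks.translation_evaluation,
    model='damo/nlp_unite_mup_translation_evaluation_multilingual_large')

inputs = {
    'hyp': ['This is a sentence.', 'This is another sentence.'],
    'src': ['这是个句子。', '这是另一个句子。'],
    'ref': ['It is a sentence.', 'It is another sentence.'],
}

result = pipe(inputs)
print(result[OutputKeys.SCORES])  # a list with one float score per hypothesis

# Re-score against the source only; the extra 'ref' entries are simply ignored.
pipe.change_eval_mode(eval_mode=EvaluationMode.SRC)
print(pipe(inputs)[OutputKeys.SCORES])
```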

modelscope/preprocessors/__init__.py  +4  -2

@@ -33,7 +33,8 @@ if TYPE_CHECKING:
DialogIntentPredictionPreprocessor, DialogModelingPreprocessor,
DialogStateTrackingPreprocessor, ConversationalTextToSqlPreprocessor,
TableQuestionAnsweringPreprocessor, NERPreprocessorViet,
NERPreprocessorThai, WordSegmentationPreprocessorThai)
NERPreprocessorThai, WordSegmentationPreprocessorThai,
TranslationEvaluationPreprocessor)
from .video import ReadVideoData, MovieSceneSegmentationPreprocessor

else:
@@ -72,7 +73,8 @@ else:
'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
'DialogStateTrackingPreprocessor',
'ConversationalTextToSqlPreprocessor',
'TableQuestionAnsweringPreprocessor'
'TableQuestionAnsweringPreprocessor',
'TranslationEvaluationPreprocessor'
],
}



modelscope/preprocessors/nlp/__init__.py  +3  -0

@@ -28,6 +28,7 @@ if TYPE_CHECKING:
from .space_T_en import ConversationalTextToSqlPreprocessor
from .space_T_cn import TableQuestionAnsweringPreprocessor
from .mglm_summarization_preprocessor import MGLMSummarizationPreprocessor
from .translation_evaluation_preprocessor import TranslationEvaluationPreprocessor
else:
_import_structure = {
'sentence_piece_preprocessor': ['SentencePiecePreprocessor'],
@@ -76,6 +77,8 @@ else:
],
'space_T_en': ['ConversationalTextToSqlPreprocessor'],
'space_T_cn': ['TableQuestionAnsweringPreprocessor'],
'translation_evaluation_preprocessor':
['TranslationEvaluationPreprocessor'],
}

import sys


modelscope/preprocessors/nlp/translation_evaluation_preprocessor.py  +87  -0

@@ -0,0 +1,87 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict, List, Union

from transformers import AutoTokenizer

from modelscope.metainfo import Preprocessors
from modelscope.models.nlp.unite.configuration_unite import EvaluationMode
from modelscope.preprocessors import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.constant import Fields, ModeKeys
from .transformers_tokenizer import NLPTokenizer


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.translation_evaluation)
class TranslationEvaluationPreprocessor(Preprocessor):
r"""The tokenizer preprocessor used for translation evaluation.
"""

def __init__(self,
model_dir: str,
eval_mode: EvaluationMode,
mode=ModeKeys.INFERENCE,
*args,
**kwargs):
r"""preprocess the data via the vocab file from the `model_dir` path

Args:
model_dir: A Model instance.
eval_mode: Evaluation mode, choosing one from `"EvaluationMode.SRC_REF"`,
`"EvaluationMode.SRC"`, `"EvaluationMode.REF"`. Aside from hypothesis, the
source/reference/source+reference can be presented during evaluation.
"""
super().__init__(mode=mode)
self.tokenizer = NLPTokenizer(
model_dir=model_dir, use_fast=False, tokenize_kwargs=kwargs)
self.eval_mode = eval_mode

return

def __call__(self, input_dict: Dict[str, Any]) -> List[List[str]]:
if self.eval_mode == EvaluationMode.SRC and 'src' not in input_dict.keys(
):
raise ValueError(
'Source sentences are required for source-only evaluation mode.'
)
if self.eval_mode == EvaluationMode.REF and 'ref' not in input_dict.keys(
):
raise ValueError(
'Reference sentences are required for reference-only evaluation mode.'
)
if self.eval_mode == EvaluationMode.SRC_REF and (
'src' not in input_dict.keys()
or 'ref' not in input_dict.keys()):
raise ValueError(
'Source and reference sentences are both required for source-reference-combined evaluation mode.'
)

if isinstance(input_dict['hyp'], str):
input_dict['hyp'] = [input_dict['hyp']]
if (self.eval_mode == EvaluationMode.SRC or self.eval_mode
== EvaluationMode.SRC_REF) and isinstance(input_dict['src'], str):
input_dict['src'] = [input_dict['src']]
if (self.eval_mode == EvaluationMode.REF or self.eval_mode
== EvaluationMode.SRC_REF) and isinstance(input_dict['ref'], str):
input_dict['ref'] = [input_dict['ref']]

output_sents = [
self.tokenizer(
input_dict['hyp'], return_tensors='pt',
padding=True)['input_ids']
]
if self.eval_mode == EvaluationMode.SRC or self.eval_mode == EvaluationMode.SRC_REF:
output_sents += [
self.tokenizer(
input_dict['src'], return_tensors='pt',
padding=True)['input_ids']
]
if self.eval_mode == EvaluationMode.REF or self.eval_mode == EvaluationMode.SRC_REF:
output_sents += [
self.tokenizer(
input_dict['ref'], return_tensors='pt',
padding=True)['input_ids']
]

return output_sents
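
A hedged sketch of the preprocessor contract: given a dict of sentence lists, it returns a list of padded `input_ids` tensors ordered `[hyp, src, ref]` in `EvaluationMode.SRC_REF` (the source or reference tensor is dropped in the single-context modes). It assumes the model directory has already been fetched, e.g. via `snapshot_download`, and contains the XLM-R tokenizer files.

```
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.nlp.unite.configuration_unite import EvaluationMode
from modelscope.preprocessors import TranslationEvaluationPreprocessor

model_dir = snapshot_download(
    'damo/nlp_unite_mup_translation_evaluation_multilingual_large')
preprocessor = TranslationEvaluationPreprocessor(
    model_dir=model_dir, eval_mode=EvaluationMode.SRC_REF)

batch = preprocessor({
    'hyp': ['This is a sentence.', 'This is another sentence.'],
    'src': ['这是个句子。', '这是另一个句子。'],
    'ref': ['It is a sentence.', 'It is another sentence.'],
})

print(len(batch))      # 3 tensors: hypothesis, source, reference
print(batch[0].shape)  # [batch_size, padded_seq_len] token ids for the hypotheses
```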

modelscope/utils/constant.py  +1  -0

@@ -133,6 +133,7 @@ class NLPTasks(object):
document_segmentation = 'document-segmentation'
extractive_summarization = 'extractive-summarization'
feature_extraction = 'feature-extraction'
translation_evaluation = 'translation-evaluation'


class AudioTasks(object):


tests/pipelines/test_translation_evaluation.py  +73  -0

@@ -0,0 +1,73 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest

from modelscope.models.nlp.unite.configuration_unite import EvaluationMode
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class TranslationEvaluationTest(unittest.TestCase, DemoCompatibilityCheck):

def setUp(self) -> None:
self.task = Tasks.translation_evaluation
self.model_id_large = 'damo/nlp_unite_mup_translation_evaluation_multilingual_large'
self.model_id_base = 'damo/nlp_unite_mup_translation_evaluation_multilingual_base'

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name_for_unite_large(self):
input_dict = {
'hyp': [
'This is a sentence.',
'This is another sentence.',
],
'src': [
'这是个句子。',
'这是另一个句子。',
],
'ref': [
'It is a sentence.',
'It is another sentence.',
]
}

pipeline_ins = pipeline(self.task, model=self.model_id_large)
print(pipeline_ins(input_dict))

pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.SRC)
print(pipeline_ins(input_dict))

pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.REF)
print(pipeline_ins(input_dict))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name_for_unite_base(self):
input_dict = {
'hyp': [
'This is a sentence.',
'This is another sentence.',
],
'src': [
'这是个句子。',
'这是另一个句子。',
],
'ref': [
'It is a sentence.',
'It is another sentence.',
]
}

pipeline_ins = pipeline(self.task, model=self.model_id_base)
print(pipeline_ins(input_dict))

pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.SRC)
print(pipeline_ins(input_dict))

pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.REF)
print(pipeline_ins(input_dict))


if __name__ == '__main__':
unittest.main()
