diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 7e66f792..afba99a7 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -90,6 +90,7 @@ class Models(object):
     mglm = 'mglm'
     codegeex = 'codegeex'
     bloom = 'bloom'
+    unite = 'unite'
 
     # audio models
     sambert_hifigan = 'sambert-hifigan'
@@ -275,6 +276,7 @@ class Pipelines(object):
     translation_en_to_ro = 'translation_en_to_ro'  # keep it underscore
     translation_en_to_fr = 'translation_en_to_fr'  # keep it underscore
     token_classification = 'token-classification'
+    translation_evaluation = 'translation-evaluation'
 
     # audio tasks
     sambert_hifigan_tts = 'sambert-hifigan-tts'
@@ -404,6 +406,7 @@ class Preprocessors(object):
     feature_extraction = 'feature-extraction'
     mglm_summarization = 'mglm-summarization'
     sentence_piece = 'sentence-piece'
+    translation_evaluation = 'translation-evaluation-preprocessor'
 
     # audio preprocessor
     linear_aec_fbank = 'linear-aec-fbank'
diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py
index 26205bcb..5d019de8 100644
--- a/modelscope/models/nlp/__init__.py
+++ b/modelscope/models/nlp/__init__.py
@@ -51,6 +51,7 @@ if TYPE_CHECKING:
         VecoForSequenceClassification, VecoForTokenClassification, VecoModel)
     from .bloom import BloomModel
+    from .unite import UniTEForTranslationEvaluation
 else:
     _import_structure = {
         'backbones': ['SbertModel'],
@@ -108,6 +109,7 @@ else:
         ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'],
         'gpt_neo': ['GPTNeoModel'],
         'bloom': ['BloomModel'],
+        'unite': ['UniTEForTranslationEvaluation']
     }
 
     import sys
diff --git a/modelscope/models/nlp/unite/__init__.py b/modelscope/models/nlp/unite/__init__.py
new file mode 100644
index 00000000..06c2146e
--- /dev/null
+++ b/modelscope/models/nlp/unite/__init__.py
@@ -0,0 +1,24 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .configuration_unite import UniTEConfig
+    from .modeling_unite import UniTEForTranslationEvaluation
+else:
+    _import_structure = {
+        'configuration_unite': ['UniTEConfig'],
+        'modeling_unite': ['UniTEForTranslationEvaluation'],
+    }
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
diff --git a/modelscope/models/nlp/unite/configuration_unite.py b/modelscope/models/nlp/unite/configuration_unite.py
new file mode 100644
index 00000000..81abd2db
--- /dev/null
+++ b/modelscope/models/nlp/unite/configuration_unite.py
@@ -0,0 +1,21 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+"""UniTE model configuration"""
+
+from enum import Enum
+
+from modelscope.utils import logger as logging
+from modelscope.utils.config import Config
+
+logger = logging.get_logger(__name__)
+
+
+class EvaluationMode(Enum):
+    SRC = 'src'
+    REF = 'ref'
+    SRC_REF = 'src-ref'
+
+
+class UniTEConfig(Config):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
diff --git a/modelscope/models/nlp/unite/modeling_unite.py b/modelscope/models/nlp/unite/modeling_unite.py
new file mode 100644
index 00000000..b341b810
--- /dev/null
+++ b/modelscope/models/nlp/unite/modeling_unite.py
@@ -0,0 +1,400 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+"""PyTorch UniTE model.""" + +import math +import warnings +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from packaging import version +from torch.nn import (Dropout, Linear, Module, Parameter, ParameterList, + Sequential) +from torch.nn.functional import softmax +from torch.nn.utils.rnn import pad_sequence +from transformers import XLMRobertaConfig, XLMRobertaModel +from transformers.activations import ACT2FN + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger(__name__) + +__all__ = ['UniTEForTranslationEvaluation'] + + +def _layer_norm_all(tensor, mask_float): + broadcast_mask = mask_float.unsqueeze(dim=-1) + num_elements_not_masked = broadcast_mask.sum() * tensor.size(-1) + tensor_masked = tensor * broadcast_mask + + mean = tensor_masked.sum([-1, -2, -3], + keepdim=True) / num_elements_not_masked + variance = (((tensor_masked - mean) * broadcast_mask)**2).sum( + [-1, -2, -3], keepdim=True) / num_elements_not_masked + + return (tensor - mean) / torch.sqrt(variance + 1e-12) + + +class LayerwiseAttention(Module): + + def __init__( + self, + num_layers: int, + model_dim: int, + dropout: float = None, + ) -> None: + super(LayerwiseAttention, self).__init__() + self.num_layers = num_layers + self.model_dim = model_dim + self.dropout = dropout + + self.scalar_parameters = Parameter( + torch.zeros((num_layers, ), requires_grad=True)) + self.gamma = Parameter(torch.FloatTensor([1.0]), requires_grad=True) + + if self.dropout: + dropout_mask = torch.zeros(len(self.scalar_parameters)) + dropout_fill = torch.empty(len( + self.scalar_parameters)).fill_(-1e20) + self.register_buffer('dropout_mask', dropout_mask) + self.register_buffer('dropout_fill', dropout_fill) + + def forward( + self, + tensors: List[torch.Tensor], # pylint: disable=arguments-differ + mask: torch.Tensor = None, + ) -> torch.Tensor: + tensors = torch.cat(list(x.unsqueeze(dim=0) for x in tensors), dim=0) + normed_weights = softmax( + self.scalar_parameters, dim=0).view(-1, 1, 1, 1) + + mask_float = mask.float() + weighted_sum = (normed_weights + * _layer_norm_all(tensors, mask_float)).sum(dim=0) + weighted_sum = weighted_sum[:, 0, :] + + return self.gamma * weighted_sum + + +class FeedForward(Module): + + def __init__( + self, + in_dim: int, + out_dim: int = 1, + hidden_sizes: List[int] = [3072, 768], + activations: str = 'Sigmoid', + final_activation: Optional[str] = None, + dropout: float = 0.1, + ) -> None: + """ + Feed Forward Neural Network. + + Args: + in_dim (:obj:`int`): + Number of input features. + out_dim (:obj:`int`, defaults to 1): + Number of output features. Default is 1 -- a single scalar. + hidden_sizes (:obj:`List[int]`, defaults to `[3072, 768]`): + List with hidden layer sizes. + activations (:obj:`str`, defaults to `Sigmoid`): + Name of the activation function to be used in the hidden layers. + final_activation (:obj:`str`, Optional, defaults to `None`): + Name of the final activation function if any. + dropout (:obj:`float`, defaults to 0.1): + Dropout ratio to be used in the hidden layers. 
+ """ + super().__init__() + modules = [] + modules.append(Linear(in_dim, hidden_sizes[0])) + modules.append(self.build_activation(activations)) + modules.append(Dropout(dropout)) + + for i in range(1, len(hidden_sizes)): + modules.append(Linear(hidden_sizes[i - 1], hidden_sizes[i])) + modules.append(self.build_activation(activations)) + modules.append(Dropout(dropout)) + + modules.append(Linear(hidden_sizes[-1], int(out_dim))) + if final_activation is not None: + modules.append(self.build_activation(final_activation)) + + self.ff = Sequential(*modules) + + def build_activation(self, activation: str) -> Module: + return ACT2FN[activation] + + def forward(self, in_features: torch.Tensor) -> torch.Tensor: + return self.ff(in_features) + + +@MODELS.register_module(Tasks.translation_evaluation, module_name=Models.unite) +class UniTEForTranslationEvaluation(TorchModel): + + def __init__(self, + attention_probs_dropout_prob: float = 0.1, + bos_token_id: int = 0, + eos_token_id: int = 2, + pad_token_id: int = 1, + hidden_act: str = 'gelu', + hidden_dropout_prob: float = 0.1, + hidden_size: int = 1024, + initializer_range: float = 0.02, + intermediate_size: int = 4096, + layer_norm_eps: float = 1e-05, + max_position_embeddings: int = 512, + num_attention_heads: int = 16, + num_hidden_layers: int = 24, + type_vocab_size: int = 1, + use_cache: bool = True, + vocab_size: int = 250002, + mlp_hidden_sizes: List[int] = [3072, 1024], + mlp_act: str = 'tanh', + mlp_final_act: Optional[str] = None, + mlp_dropout: float = 0.1, + **kwargs): + r"""The UniTE Model which outputs the scalar to describe the corresponding + translation quality of hypothesis. The model architecture includes two + modules: a pre-trained language model (PLM) to derive representations, + and a multi-layer perceptron (MLP) to give predicted score. + + Args: + attention_probs_dropout_prob (:obj:`float`, defaults to 0.1): + The dropout ratio for attention weights inside PLM. + bos_token_id (:obj:`int`, defaults to 0): + The numeric id representing beginning-of-sentence symbol. + eos_token_id (:obj:`int`, defaults to 2): + The numeric id representing ending-of-sentence symbol. + pad_token_id (:obj:`int`, defaults to 1): + The numeric id representing padding symbol. + hidden_act (:obj:`str`, defaults to :obj:`"gelu"`): + Activation inside PLM. + hidden_dropout_prob (:obj:`float`, defaults to 0.1): + The dropout ratio for activation states inside PLM. + hidden_size (:obj:`int`, defaults to 1024): + The dimensionality of PLM. + initializer_range (:obj:`float`, defaults to 0.02): + The hyper-parameter for initializing PLM. + intermediate_size (:obj:`int`, defaults to 4096): + The dimensionality of PLM inside feed-forward block. + layer_norm_eps (:obj:`float`, defaults to 1e-5): + The value for setting epsilon to avoid zero-division inside + layer normalization. + max_position_embeddings: (:obj:`int`, defaults to 512): + The maximum value for identifying the length of input sequence. + num_attention_heads (:obj:`int`, defaults to 16): + The number of attention heads inside multi-head attention layer. + num_hidden_layers (:obj:`int`, defaults to 24): + The number of layers inside PLM. + type_vocab_size (:obj:`int`, defaults to 1): + The number of type embeddings. + use_cache (:obj:`bool`, defaults to :obj:`True`): + Whether to use cached buffer to initialize PLM. + vocab_size (:obj:`int`, defaults to 250002): + The size of vocabulary. + mlp_hidden_sizes (:obj:`List[int]`, defaults to `[3072, 1024]`): + The size of hidden states inside MLP. 
+            mlp_act (:obj:`str`, defaults to :obj:`"tanh"`):
+                Activation inside MLP.
+            mlp_final_act (:obj:`str`, `optional`, defaults to :obj:`None`):
+                Activation at the end of MLP.
+            mlp_dropout (:obj:`float`, defaults to 0.1):
+                The dropout ratio for MLP.
+        """
+        super().__init__(**kwargs)
+
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+        self.pad_token_id = pad_token_id
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.hidden_size = hidden_size
+        self.initializer_range = initializer_range
+        self.intermediate_size = intermediate_size
+        self.layer_norm_eps = layer_norm_eps
+        self.max_position_embeddings = max_position_embeddings
+        self.num_attention_heads = num_attention_heads
+        self.num_hidden_layers = num_hidden_layers
+        self.type_vocab_size = type_vocab_size
+        self.use_cache = use_cache
+        self.vocab_size = vocab_size
+        self.mlp_hidden_sizes = mlp_hidden_sizes
+        self.mlp_act = mlp_act
+        self.mlp_final_act = mlp_final_act
+        self.mlp_dropout = mlp_dropout
+
+        self.encoder_config = XLMRobertaConfig(
+            bos_token_id=self.bos_token_id,
+            eos_token_id=self.eos_token_id,
+            pad_token_id=self.pad_token_id,
+            vocab_size=self.vocab_size,
+            hidden_size=self.hidden_size,
+            num_hidden_layers=self.num_hidden_layers,
+            num_attention_heads=self.num_attention_heads,
+            intermediate_size=self.intermediate_size,
+            hidden_act=self.hidden_act,
+            hidden_dropout_prob=self.hidden_dropout_prob,
+            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+            max_position_embeddings=self.max_position_embeddings,
+            type_vocab_size=self.type_vocab_size,
+            initializer_range=self.initializer_range,
+            layer_norm_eps=self.layer_norm_eps,
+            use_cache=self.use_cache)
+
+        self.encoder = XLMRobertaModel(
+            self.encoder_config, add_pooling_layer=False)
+
+        self.layerwise_attention = LayerwiseAttention(
+            num_layers=self.num_hidden_layers + 1,
+            model_dim=self.hidden_size,
+            dropout=self.mlp_dropout)
+
+        self.estimator = FeedForward(
+            in_dim=self.hidden_size,
+            out_dim=1,
+            hidden_sizes=self.mlp_hidden_sizes,
+            activations=self.mlp_act,
+            final_activation=self.mlp_final_act,
+            dropout=self.mlp_dropout)
+
+        return
+
+    def forward(self, input_sentences: List[torch.Tensor]):
+        input_ids = self.combine_input_sentences(input_sentences)
+        attention_mask = input_ids.ne(self.pad_token_id).long()
+        outputs = self.encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_hidden_states=True,
+            return_dict=True)
+        mix_states = self.layerwise_attention(outputs['hidden_states'],
+                                              attention_mask)
+        pred = self.estimator(mix_states)
+        return pred.squeeze(dim=-1)
+
+    def load_checkpoint(self, path: str):
+        state_dict = torch.load(path)
+        self.load_state_dict(state_dict)
+        logger.info('Loading checkpoint parameters from %s' % path)
+        return
+
+    def combine_input_sentences(self, input_sent_groups: List[torch.Tensor]):
+        for input_sent_group in input_sent_groups[1:]:
+            input_sent_group[:, 0] = self.eos_token_id
+
+        if len(input_sent_groups) == 3:
+            cutted_sents = self.cut_long_sequences3(input_sent_groups)
+        else:
+            cutted_sents = self.cut_long_sequences2(input_sent_groups)
+        return cutted_sents
+
+    @staticmethod
+    def cut_long_sequences2(all_input_concat: List[List[torch.Tensor]],
+                            maximum_length: int = 512,
+                            pad_idx: int = 1):
+        all_input_concat = list(zip(*all_input_concat))
+        collected_tuples = list()
+        for tensor_tuple in all_input_concat:
+            all_lens = tuple(len(x) for x in tensor_tuple)
+
+            if sum(all_lens) > maximum_length:
+                lengths = dict(enumerate(all_lens))
+                lengths_sorted_idxes = list(x[0] for x in sorted(
+                    lengths.items(), key=lambda d: d[1], reverse=True))
+
+                offset = ceil((sum(lengths.values()) - maximum_length) / 2)
+
+                if min(all_lens) > (maximum_length
+                                    // 2) and min(all_lens) > offset:
+                    lengths = dict((k, v - offset) for k, v in lengths.items())
+                else:
+                    lengths[lengths_sorted_idxes[
+                        0]] = maximum_length - lengths[lengths_sorted_idxes[1]]
+
+                new_lens = list(lengths[k]
+                                for k in range(0, len(tensor_tuple)))
+                new_tensor_tuple = tuple(
+                    x[:y] for x, y in zip(tensor_tuple, new_lens))
+                for x, y in zip(new_tensor_tuple, tensor_tuple):
+                    x[-1] = y[-1]
+                collected_tuples.append(new_tensor_tuple)
+            else:
+                collected_tuples.append(tensor_tuple)
+
+        concat_tensor = list(torch.cat(x, dim=0) for x in collected_tuples)
+        all_input_concat_padded = pad_sequence(
+            concat_tensor, batch_first=True, padding_value=pad_idx)
+
+        return all_input_concat_padded
+
+    @staticmethod
+    def cut_long_sequences3(all_input_concat: List[List[torch.Tensor]],
+                            maximum_length: int = 512,
+                            pad_idx: int = 1):
+        all_input_concat = list(zip(*all_input_concat))
+        collected_tuples = list()
+        for tensor_tuple in all_input_concat:
+            all_lens = tuple(len(x) for x in tensor_tuple)
+
+            if sum(all_lens) > maximum_length:
+                lengths = dict(enumerate(all_lens))
+                lengths_sorted_idxes = list(x[0] for x in sorted(
+                    lengths.items(), key=lambda d: d[1], reverse=True))
+
+                offset = ceil((sum(lengths.values()) - maximum_length) / 3)
+
+                if min(all_lens) > (maximum_length
+                                    // 3) and min(all_lens) > offset:
+                    lengths = dict((k, v - offset) for k, v in lengths.items())
+                else:
+                    while sum(lengths.values()) > maximum_length:
+                        if lengths[lengths_sorted_idxes[0]] > lengths[
+                                lengths_sorted_idxes[1]]:
+                            offset = maximum_length - lengths[
+                                lengths_sorted_idxes[1]] - lengths[
+                                    lengths_sorted_idxes[2]]
+                            if offset > lengths[lengths_sorted_idxes[1]]:
+                                lengths[lengths_sorted_idxes[0]] = offset
+                            else:
+                                lengths[lengths_sorted_idxes[0]] = lengths[
+                                    lengths_sorted_idxes[1]]
+                        elif lengths[lengths_sorted_idxes[0]] == lengths[
+                                lengths_sorted_idxes[1]] > lengths[
+                                    lengths_sorted_idxes[2]]:
+                            offset = (maximum_length
+                                      - lengths[lengths_sorted_idxes[2]]) // 2
+                            if offset > lengths[lengths_sorted_idxes[2]]:
+                                lengths[lengths_sorted_idxes[0]] = lengths[
+                                    lengths_sorted_idxes[1]] = offset
+                            else:
+                                lengths[lengths_sorted_idxes[0]] = lengths[
+                                    lengths_sorted_idxes[1]] = lengths[
+                                        lengths_sorted_idxes[2]]
+                        else:
+                            lengths[lengths_sorted_idxes[0]] = lengths[
+                                lengths_sorted_idxes[1]] = lengths[
+                                    lengths_sorted_idxes[
+                                        2]] = maximum_length // 3
+
+                new_lens = list(lengths[k] for k in range(0, len(lengths)))
+                new_tensor_tuple = tuple(
+                    x[:y] for x, y in zip(tensor_tuple, new_lens))
+
+                for x, y in zip(new_tensor_tuple, tensor_tuple):
+                    x[-1] = y[-1]
+                collected_tuples.append(new_tensor_tuple)
+            else:
+                collected_tuples.append(tensor_tuple)
+
+        concat_tensor = list(torch.cat(x, dim=0) for x in collected_tuples)
+        all_input_concat_padded = pad_sequence(
+            concat_tensor, batch_first=True, padding_value=pad_idx)
+
+        return all_input_concat_padded
diff --git a/modelscope/outputs/outputs.py b/modelscope/outputs/outputs.py
index dbd1ec3c..94a8d035 100644
--- a/modelscope/outputs/outputs.py
+++ b/modelscope/outputs/outputs.py
@@ -801,6 +801,11 @@ TASK_OUTPUTS = {
     #     ]
     # }
     Tasks.product_segmentation: [OutputKeys.MASKS],
+
+    # {
+    #     'scores': [0.1, 0.2, 0.3, ...]
+    # }
+    Tasks.translation_evaluation: [OutputKeys.SCORES]
 }
diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py
index 060049ef..0e44fcac 100644
--- a/modelscope/pipeline_inputs.py
+++ b/modelscope/pipeline_inputs.py
@@ -183,6 +183,11 @@ TASK_INPUTS = {
         'query_set': InputType.LIST,
         'support_set': InputType.LIST,
     },
+    Tasks.translation_evaluation: {
+        'hyp': InputType.LIST,
+        'src': InputType.LIST,
+        'ref': InputType.LIST,
+    },
 
     # ============ audio tasks ===================
     Tasks.auto_speech_recognition:
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index dac6011d..68d4f0b1 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -217,6 +217,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
         'damo/cv_swin-t_referring_video-object-segmentation'),
     Tasks.video_summarization: (Pipelines.video_summarization,
                                 'damo/cv_googlenet_pgl-video-summarization'),
+    Tasks.translation_evaluation:
+    (Pipelines.translation_evaluation,
+     'damo/nlp_unite_mup_translation_evaluation_multilingual_large'),
 }
diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py
index eaff2144..707e2ac0 100644
--- a/modelscope/pipelines/nlp/__init__.py
+++ b/modelscope/pipelines/nlp/__init__.py
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
     from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline
     from .codegeex_code_translation_pipeline import CodeGeeXCodeTranslationPipeline
    from .codegeex_code_generation_pipeline import CodeGeeXCodeGenerationPipeline
+    from .translation_evaluation_pipeline import TranslationEvaluationPipeline
 
 else:
     _import_structure = {
@@ -77,6 +78,7 @@ else:
         ['CodeGeeXCodeTranslationPipeline'],
         'codegeex_code_generation_pipeline':
         ['CodeGeeXCodeGenerationPipeline'],
+        'translation_evaluation_pipeline': ['TranslationEvaluationPipeline'],
     }
 
     import sys
diff --git a/modelscope/pipelines/nlp/translation_evaluation_pipeline.py b/modelscope/pipelines/nlp/translation_evaluation_pipeline.py
new file mode 100644
index 00000000..bc942342
--- /dev/null
+++ b/modelscope/pipelines/nlp/translation_evaluation_pipeline.py
@@ -0,0 +1,111 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os.path as osp
+from enum import Enum
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.base import Model
+from modelscope.models.nlp.unite.configuration_unite import EvaluationMode
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import InputModel, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import (Preprocessor,
+                                      TranslationEvaluationPreprocessor)
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+__all__ = ['TranslationEvaluationPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.translation_evaluation, module_name=Pipelines.translation_evaluation)
+class TranslationEvaluationPipeline(Pipeline):
+
+    def __init__(self,
+                 model: InputModel,
+                 preprocessor: Optional[Preprocessor] = None,
+                 eval_mode: EvaluationMode = EvaluationMode.SRC_REF,
+                 **kwargs):
+        r"""Build a translation evaluation pipeline with a model dir or a model id in the model hub.
+
+        Args:
+            model: A Model instance, or a model id / model dir in the model hub.
+            eval_mode: Evaluation mode, one of `"EvaluationMode.SRC_REF"`,
+                `"EvaluationMode.SRC"` or `"EvaluationMode.REF"`, deciding whether the
+                source, the reference, or both are presented to the model along with the hypothesis.
+        """
+        super().__init__(model=model, preprocessor=preprocessor)
+
+        self.eval_mode = eval_mode
+        self.checking_eval_mode()
+
+        self.preprocessor = TranslationEvaluationPreprocessor(
+            self.model.model_dir,
+            self.eval_mode) if preprocessor is None else preprocessor
+
+        self.model.load_checkpoint(
+            osp.join(self.model.model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
+        self.model.eval()
+
+        return
+
+    def checking_eval_mode(self):
+        if self.eval_mode == EvaluationMode.SRC:
+            logger.info('Evaluation mode: source-only')
+        elif self.eval_mode == EvaluationMode.REF:
+            logger.info('Evaluation mode: reference-only')
+        elif self.eval_mode == EvaluationMode.SRC_REF:
+            logger.info('Evaluation mode: source-reference-combined')
+        else:
+            raise ValueError(
+                'Evaluation mode should be one choice among '
+                '\'EvaluationMode.SRC\', \'EvaluationMode.REF\', and '
+                '\'EvaluationMode.SRC_REF\'.')
+
+    def change_eval_mode(self,
+                         eval_mode: EvaluationMode = EvaluationMode.SRC_REF):
+        logger.info('Changing the evaluation mode.')
+        self.eval_mode = eval_mode
+        self.checking_eval_mode()
+        self.preprocessor.eval_mode = eval_mode
+        return
+
+    def __call__(self, input_dict: Dict[str, Union[str, List[str]]], **kwargs):
+        r"""Score the translation hypotheses given in `input_dict`.
+
+        Args:
+            input_dict: A dict containing the input sentences.
+                An example of the expected format:
+                ```
+                input_dict = {
+                    'hyp': [
+                        'This is a sentence.',
+                        'This is another sentence.',
+                    ],
+                    'src': [
+                        '这是个句子。',
+                        '这是另一个句子。',
+                    ],
+                    'ref': [
+                        'It is a sentence.',
+                        'It is another sentence.',
+                    ]
+                }
+                ```
+        """
+        return super().__call__(input=input_dict, **kwargs)
+
+    def forward(self,
+                input_ids: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
+        return self.model(input_ids)
+
+    def postprocess(self, output: torch.Tensor) -> Dict[str, Any]:
+        result = {OutputKeys.SCORES: output.cpu().tolist()}
+        return result
diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py
index b4adf935..79a2e489 100644
--- a/modelscope/preprocessors/__init__.py
+++ b/modelscope/preprocessors/__init__.py
@@ -33,7 +33,8 @@ if TYPE_CHECKING:
         DialogIntentPredictionPreprocessor, DialogModelingPreprocessor,
         DialogStateTrackingPreprocessor, ConversationalTextToSqlPreprocessor,
         TableQuestionAnsweringPreprocessor, NERPreprocessorViet,
-        NERPreprocessorThai, WordSegmentationPreprocessorThai)
+        NERPreprocessorThai, WordSegmentationPreprocessorThai,
+        TranslationEvaluationPreprocessor)
     from .video import ReadVideoData, MovieSceneSegmentationPreprocessor
 
 else:
@@ -72,7 +73,8 @@ else:
             'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
             'DialogStateTrackingPreprocessor',
             'ConversationalTextToSqlPreprocessor',
-            'TableQuestionAnsweringPreprocessor'
+            'TableQuestionAnsweringPreprocessor',
+            'TranslationEvaluationPreprocessor'
         ],
     }
diff --git a/modelscope/preprocessors/nlp/__init__.py b/modelscope/preprocessors/nlp/__init__.py
index 8ee9a80c..c6fa2025 100644
--- a/modelscope/preprocessors/nlp/__init__.py
+++ b/modelscope/preprocessors/nlp/__init__.py
@@ -28,6 +28,7 @@ if TYPE_CHECKING:
     from .space_T_en import ConversationalTextToSqlPreprocessor
     from .space_T_cn import TableQuestionAnsweringPreprocessor
     from .mglm_summarization_preprocessor import MGLMSummarizationPreprocessor
+    from .translation_evaluation_preprocessor import TranslationEvaluationPreprocessor
 else:
     _import_structure = {
         'sentence_piece_preprocessor': ['SentencePiecePreprocessor'],
@@ -76,6 +77,8 @@ else:
         ],
         'space_T_en': ['ConversationalTextToSqlPreprocessor'],
         'space_T_cn': ['TableQuestionAnsweringPreprocessor'],
+        'translation_evaluation_preprocessor':
+        ['TranslationEvaluationPreprocessor'],
     }
 
     import sys
diff --git a/modelscope/preprocessors/nlp/translation_evaluation_preprocessor.py b/modelscope/preprocessors/nlp/translation_evaluation_preprocessor.py
new file mode 100644
index 00000000..0bf62cdc
--- /dev/null
+++ b/modelscope/preprocessors/nlp/translation_evaluation_preprocessor.py
@@ -0,0 +1,87 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from typing import Any, Dict, List, Union
+
+from transformers import AutoTokenizer
+
+from modelscope.metainfo import Preprocessors
+from modelscope.models.nlp.unite.configuration_unite import EvaluationMode
+from modelscope.preprocessors import Preprocessor
+from modelscope.preprocessors.builder import PREPROCESSORS
+from modelscope.utils.constant import Fields, ModeKeys
+from .transformers_tokenizer import NLPTokenizer
+
+
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=Preprocessors.translation_evaluation)
+class TranslationEvaluationPreprocessor(Preprocessor):
+    r"""The tokenizer preprocessor used for translation evaluation.
+    """
+
+    def __init__(self,
+                 model_dir: str,
+                 eval_mode: EvaluationMode,
+                 mode=ModeKeys.INFERENCE,
+                 *args,
+                 **kwargs):
+        r"""Preprocess the data using the vocab files under the `model_dir` path.
+
+        Args:
+            model_dir: The model directory containing the vocabulary files.
+            eval_mode: Evaluation mode, one of `"EvaluationMode.SRC_REF"`,
+                `"EvaluationMode.SRC"` or `"EvaluationMode.REF"`, deciding whether the
+                source, the reference, or both are presented along with the hypothesis.
+        """
+        super().__init__(mode=mode)
+        self.tokenizer = NLPTokenizer(
+            model_dir=model_dir, use_fast=False, tokenize_kwargs=kwargs)
+        self.eval_mode = eval_mode
+
+        return
+
+    def __call__(self, input_dict: Dict[str, Any]) -> List[Any]:
+        if self.eval_mode == EvaluationMode.SRC and 'src' not in input_dict.keys(
+        ):
+            raise ValueError(
+                'Source sentences are required for source-only evaluation mode.'
+            )
+        if self.eval_mode == EvaluationMode.REF and 'ref' not in input_dict.keys(
+        ):
+            raise ValueError(
+                'Reference sentences are required for reference-only evaluation mode.'
+            )
+        if self.eval_mode == EvaluationMode.SRC_REF and (
+                'src' not in input_dict.keys()
+                or 'ref' not in input_dict.keys()):
+            raise ValueError(
+                'Source and reference sentences are both required for source-reference-combined evaluation mode.'
+            )
+
+        if type(input_dict['hyp']) == str:
+            input_dict['hyp'] = [input_dict['hyp']]
+        if (self.eval_mode == EvaluationMode.SRC or self.eval_mode
+                == EvaluationMode.SRC_REF) and type(input_dict['src']) == str:
+            input_dict['src'] = [input_dict['src']]
+        if (self.eval_mode == EvaluationMode.REF or self.eval_mode
+                == EvaluationMode.SRC_REF) and type(input_dict['ref']) == str:
+            input_dict['ref'] = [input_dict['ref']]
+
+        output_sents = [
+            self.tokenizer(
+                input_dict['hyp'], return_tensors='pt',
+                padding=True)['input_ids']
+        ]
+        if self.eval_mode == EvaluationMode.SRC or self.eval_mode == EvaluationMode.SRC_REF:
+            output_sents += [
+                self.tokenizer(
+                    input_dict['src'], return_tensors='pt',
+                    padding=True)['input_ids']
+            ]
+        if self.eval_mode == EvaluationMode.REF or self.eval_mode == EvaluationMode.SRC_REF:
+            output_sents += [
+                self.tokenizer(
+                    input_dict['ref'], return_tensors='pt',
+                    padding=True)['input_ids']
+            ]
+
+        return output_sents
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 8376c971..4d585e1a 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -133,6 +133,7 @@ class NLPTasks(object):
     document_segmentation = 'document-segmentation'
     extractive_summarization = 'extractive-summarization'
     feature_extraction = 'feature-extraction'
+    translation_evaluation = 'translation-evaluation'
 
 
 class AudioTasks(object):
diff --git a/tests/pipelines/test_translation_evaluation.py b/tests/pipelines/test_translation_evaluation.py
new file mode 100644
index 00000000..0c73edca
--- /dev/null
+++ b/tests/pipelines/test_translation_evaluation.py
@@ -0,0 +1,73 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import unittest
+
+from modelscope.models.nlp.unite.configuration_unite import EvaluationMode
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class TranslationEvaluationTest(unittest.TestCase, DemoCompatibilityCheck):
+
+    def setUp(self) -> None:
+        self.task = Tasks.translation_evaluation
+        self.model_id_large = 'damo/nlp_unite_mup_translation_evaluation_multilingual_large'
+        self.model_id_base = 'damo/nlp_unite_mup_translation_evaluation_multilingual_base'
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name_for_unite_large(self):
+        input_dict = {
+            'hyp': [
+                'This is a sentence.',
+                'This is another sentence.',
+            ],
+            'src': [
+                '这是个句子。',
+                '这是另一个句子。',
+            ],
+            'ref': [
+                'It is a sentence.',
+                'It is another sentence.',
+            ]
+        }
+
+        pipeline_ins = pipeline(self.task, model=self.model_id_large)
+        print(pipeline_ins(input_dict))
+
+        pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.SRC)
+        print(pipeline_ins(input_dict))
+
+        pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.REF)
+        print(pipeline_ins(input_dict))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name_for_unite_base(self):
+        input_dict = {
+            'hyp': [
+                'This is a sentence.',
+                'This is another sentence.',
+            ],
+            'src': [
+                '这是个句子。',
+                '这是另一个句子。',
+            ],
+            'ref': [
+                'It is a sentence.',
+                'It is another sentence.',
+            ]
+        }
+
+        pipeline_ins = pipeline(self.task, model=self.model_id_base)
+        print(pipeline_ins(input_dict))
+
+        pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.SRC)
+        print(pipeline_ins(input_dict))
+
+        pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.REF)
+        print(pipeline_ins(input_dict))
+
+
+if __name__ == '__main__':
+    unittest.main()
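
For reference, a minimal usage sketch of the new pipeline outside the unittest harness, mirroring the test above; the printed score values are only illustrative placeholders, and the snippet assumes the model weights can be fetched from the hub:

```
from modelscope.models.nlp.unite.configuration_unite import EvaluationMode
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the pipeline from the default model registered in builder.py.
pipeline_ins = pipeline(
    Tasks.translation_evaluation,
    model='damo/nlp_unite_mup_translation_evaluation_multilingual_large')

input_dict = {
    'hyp': ['This is a sentence.', 'This is another sentence.'],
    'src': ['这是个句子。', '这是另一个句子。'],
    'ref': ['It is a sentence.', 'It is another sentence.'],
}

# Default SRC_REF mode scores each hypothesis against source and reference.
print(pipeline_ins(input_dict))  # e.g. {'scores': [..., ...]}

# Switch to source-only evaluation; the 'ref' entry is then not used.
pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.SRC)
print(pipeline_ins(input_dict))
```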