Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10909489 (master^2)
@@ -90,6 +90,7 @@ class Models(object):
    mglm = 'mglm'
    codegeex = 'codegeex'
    bloom = 'bloom'
    unite = 'unite'
    # audio models
    sambert_hifigan = 'sambert-hifigan'
@@ -275,6 +276,7 @@ class Pipelines(object):
    translation_en_to_ro = 'translation_en_to_ro'  # keep it underscore
    translation_en_to_fr = 'translation_en_to_fr'  # keep it underscore
    token_classification = 'token-classification'
    translation_evaluation = 'translation-evaluation'
    # audio tasks
    sambert_hifigan_tts = 'sambert-hifigan-tts'
@@ -404,6 +406,7 @@ class Preprocessors(object):
    feature_extraction = 'feature-extraction'
    mglm_summarization = 'mglm-summarization'
    sentence_piece = 'sentence-piece'
    translation_evaluation = 'translation-evaluation-preprocessor'
    # audio preprocessor
    linear_aec_fbank = 'linear-aec-fbank'
@@ -51,6 +51,7 @@ if TYPE_CHECKING:
        VecoForSequenceClassification,
        VecoForTokenClassification, VecoModel)
    from .bloom import BloomModel
    from .unite import UniTEForTranslationEvaluation
else:
    _import_structure = {
        'backbones': ['SbertModel'],
@@ -108,6 +109,7 @@ else:
        ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'],
        'gpt_neo': ['GPTNeoModel'],
        'bloom': ['BloomModel'],
        'unite': ['UniTEForTranslationEvaluation'],
    }

    import sys
@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .configuration_unite import UniTEConfig
    from .modeling_unite import UniTEForTranslationEvaluation
else:
    _import_structure = {
        'configuration_unite': ['UniTEConfig'],
        'modeling_unite': ['UniTEForTranslationEvaluation'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
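For readers unfamiliar with `LazyImportModule`: the module body above replaces itself in `sys.modules`, so the heavy `torch`/`transformers` imports only run on first attribute access. A minimal sketch (not part of the diff) of the effect:

```python
# Importing the package is cheap; `unite` is a LazyImportModule proxy here.
from modelscope.models.nlp import unite

# The first attribute access triggers the real import of modeling_unite.
model_cls = unite.UniTEForTranslationEvaluation
print(model_cls.__name__)  # UniTEForTranslationEvaluation
```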
@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
"""UniTE model configuration."""
from enum import Enum

from modelscope.utils import logger as logging
from modelscope.utils.config import Config

logger = logging.get_logger(__name__)


class EvaluationMode(Enum):
    SRC = 'src'
    REF = 'ref'
    SRC_REF = 'src-ref'


class UniTEConfig(Config):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
@@ -0,0 +1,400 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
"""PyTorch UniTE model."""
import math
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from packaging import version
from torch.nn import (Dropout, Linear, Module, Parameter, ParameterList,
                      Sequential)
from torch.nn.functional import softmax
from torch.nn.utils.rnn import pad_sequence
from transformers import XLMRobertaConfig, XLMRobertaModel
from transformers.activations import ACT2FN

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger(__name__)

__all__ = ['UniTEForTranslationEvaluation']
def _layer_norm_all(tensor, mask_float):
    broadcast_mask = mask_float.unsqueeze(dim=-1)
    num_elements_not_masked = broadcast_mask.sum() * tensor.size(-1)
    tensor_masked = tensor * broadcast_mask

    mean = tensor_masked.sum([-1, -2, -3],
                             keepdim=True) / num_elements_not_masked
    variance = (((tensor_masked - mean) * broadcast_mask)**2).sum(
        [-1, -2, -3], keepdim=True) / num_elements_not_masked

    return (tensor - mean) / torch.sqrt(variance + 1e-12)
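`_layer_norm_all` normalizes each layer's stacked hidden states with mean and variance computed only over the non-padding positions (shared across batch, time, and feature dimensions). A small sanity-check sketch with dummy shapes:

```python
import torch

# hidden: [num_layers, batch, seq_len, dim]; mask: [batch, seq_len],
# 1.0 for real tokens and 0.0 for padding.
hidden = torch.randn(3, 2, 5, 4)
mask = torch.tensor([[1., 1., 1., 0., 0.],
                     [1., 1., 1., 1., 1.]])

normed = _layer_norm_all(hidden, mask)
valid = (mask > 0)[None, :, :, None].expand_as(normed)
print(normed[valid].mean().item())  # ~0.0: statistics ignore the padding
```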
class LayerwiseAttention(Module):

    def __init__(
        self,
        num_layers: int,
        model_dim: int,
        dropout: Optional[float] = None,
    ) -> None:
        super(LayerwiseAttention, self).__init__()
        self.num_layers = num_layers
        self.model_dim = model_dim
        self.dropout = dropout

        self.scalar_parameters = Parameter(
            torch.zeros((num_layers, ), requires_grad=True))
        self.gamma = Parameter(torch.FloatTensor([1.0]), requires_grad=True)

        if self.dropout:
            dropout_mask = torch.zeros(len(self.scalar_parameters))
            dropout_fill = torch.empty(len(
                self.scalar_parameters)).fill_(-1e20)
            self.register_buffer('dropout_mask', dropout_mask)
            self.register_buffer('dropout_fill', dropout_fill)

    def forward(
        self,
        tensors: List[torch.Tensor],  # pylint: disable=arguments-differ
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        tensors = torch.cat(list(x.unsqueeze(dim=0) for x in tensors), dim=0)
        normed_weights = softmax(
            self.scalar_parameters, dim=0).view(-1, 1, 1, 1)
        mask_float = mask.float()

        weighted_sum = (normed_weights
                        * _layer_norm_all(tensors, mask_float)).sum(dim=0)
        # Keep only the representation at the first (<s>) position.
        weighted_sum = weighted_sum[:, 0, :]

        return self.gamma * weighted_sum
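`LayerwiseAttention` is essentially an ELMo-style scalar mix: a softmax over one learnable weight per layer, applied to the normalized hidden states, after which only the first-position vector is kept. A usage sketch with dummy tensors:

```python
# 25 hidden-state tensors, as returned by a 24-layer encoder with
# output_hidden_states=True (embedding layer + 24 transformer layers).
mixer = LayerwiseAttention(num_layers=25, model_dim=1024)
hidden_states = [torch.randn(2, 7, 1024) for _ in range(25)]
mask = torch.ones(2, 7)

pooled = mixer(hidden_states, mask)
print(pooled.shape)  # torch.Size([2, 1024]) -- one vector per example
```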
class FeedForward(Module):

    def __init__(
        self,
        in_dim: int,
        out_dim: int = 1,
        hidden_sizes: List[int] = [3072, 768],
        activations: str = 'sigmoid',
        final_activation: Optional[str] = None,
        dropout: float = 0.1,
    ) -> None:
        """
        Feed-forward neural network.

        Args:
            in_dim (:obj:`int`):
                Number of input features.
            out_dim (:obj:`int`, defaults to 1):
                Number of output features. The default of 1 yields a single
                scalar per example.
            hidden_sizes (:obj:`List[int]`, defaults to `[3072, 768]`):
                Sizes of the hidden layers.
            activations (:obj:`str`, defaults to `sigmoid`):
                Name of the activation function used in the hidden layers
                (must be a key of `transformers.activations.ACT2FN`).
            final_activation (:obj:`str`, optional, defaults to `None`):
                Name of the final activation function, if any.
            dropout (:obj:`float`, defaults to 0.1):
                Dropout ratio used in the hidden layers.
        """
        super().__init__()
        modules = []
        modules.append(Linear(in_dim, hidden_sizes[0]))
        modules.append(self.build_activation(activations))
        modules.append(Dropout(dropout))

        for i in range(1, len(hidden_sizes)):
            modules.append(Linear(hidden_sizes[i - 1], hidden_sizes[i]))
            modules.append(self.build_activation(activations))
            modules.append(Dropout(dropout))

        modules.append(Linear(hidden_sizes[-1], int(out_dim)))
        if final_activation is not None:
            modules.append(self.build_activation(final_activation))
        self.ff = Sequential(*modules)

    def build_activation(self, activation: str) -> Module:
        return ACT2FN[activation]

    def forward(self, in_features: torch.Tensor) -> torch.Tensor:
        return self.ff(in_features)
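Illustrative usage of this estimator head, mirroring how the model below configures it: map a pooled 1024-dim sentence representation to one scalar per example.

```python
head = FeedForward(
    in_dim=1024, out_dim=1, hidden_sizes=[3072, 1024], activations='tanh')

scores = head(torch.randn(2, 1024))
print(scores.shape)  # torch.Size([2, 1])
```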
@MODELS.register_module(Tasks.translation_evaluation, module_name=Models.unite)
class UniTEForTranslationEvaluation(TorchModel):

    def __init__(self,
                 attention_probs_dropout_prob: float = 0.1,
                 bos_token_id: int = 0,
                 eos_token_id: int = 2,
                 pad_token_id: int = 1,
                 hidden_act: str = 'gelu',
                 hidden_dropout_prob: float = 0.1,
                 hidden_size: int = 1024,
                 initializer_range: float = 0.02,
                 intermediate_size: int = 4096,
                 layer_norm_eps: float = 1e-05,
                 max_position_embeddings: int = 512,
                 num_attention_heads: int = 16,
                 num_hidden_layers: int = 24,
                 type_vocab_size: int = 1,
                 use_cache: bool = True,
                 vocab_size: int = 250002,
                 mlp_hidden_sizes: List[int] = [3072, 1024],
                 mlp_act: str = 'tanh',
                 mlp_final_act: Optional[str] = None,
                 mlp_dropout: float = 0.1,
                 **kwargs):
        r"""The UniTE model, which outputs a scalar describing the translation
        quality of a hypothesis. The architecture comprises two modules: a
        pre-trained language model (PLM) that derives representations, and a
        multi-layer perceptron (MLP) that produces the predicted score.

        Args:
            attention_probs_dropout_prob (:obj:`float`, defaults to 0.1):
                The dropout ratio for attention weights inside the PLM.
            bos_token_id (:obj:`int`, defaults to 0):
                The numeric id of the beginning-of-sentence symbol.
            eos_token_id (:obj:`int`, defaults to 2):
                The numeric id of the end-of-sentence symbol.
            pad_token_id (:obj:`int`, defaults to 1):
                The numeric id of the padding symbol.
            hidden_act (:obj:`str`, defaults to :obj:`"gelu"`):
                The activation function inside the PLM.
            hidden_dropout_prob (:obj:`float`, defaults to 0.1):
                The dropout ratio for hidden states inside the PLM.
            hidden_size (:obj:`int`, defaults to 1024):
                The hidden dimensionality of the PLM.
            initializer_range (:obj:`float`, defaults to 0.02):
                The standard deviation used to initialize the PLM's weights.
            intermediate_size (:obj:`int`, defaults to 4096):
                The dimensionality of the PLM's feed-forward block.
            layer_norm_eps (:obj:`float`, defaults to 1e-5):
                The epsilon added inside layer normalization to avoid
                division by zero.
            max_position_embeddings (:obj:`int`, defaults to 512):
                The maximum supported length of an input sequence.
            num_attention_heads (:obj:`int`, defaults to 16):
                The number of heads in each multi-head attention layer.
            num_hidden_layers (:obj:`int`, defaults to 24):
                The number of layers inside the PLM.
            type_vocab_size (:obj:`int`, defaults to 1):
                The number of token-type embeddings.
            use_cache (:obj:`bool`, defaults to :obj:`True`):
                Whether the PLM should cache key/value states.
            vocab_size (:obj:`int`, defaults to 250002):
                The size of the vocabulary.
            mlp_hidden_sizes (:obj:`List[int]`, defaults to `[3072, 1024]`):
                The sizes of the hidden layers inside the MLP.
            mlp_act (:obj:`str`, defaults to :obj:`"tanh"`):
                The activation function inside the MLP.
            mlp_final_act (:obj:`str`, `optional`, defaults to :obj:`None`):
                The activation at the end of the MLP, if any.
            mlp_dropout (:obj:`float`, defaults to 0.1):
                The dropout ratio for the MLP.
        """
        super().__init__(**kwargs)
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range
        self.intermediate_size = intermediate_size
        self.layer_norm_eps = layer_norm_eps
        self.max_position_embeddings = max_position_embeddings
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.type_vocab_size = type_vocab_size
        self.use_cache = use_cache
        self.vocab_size = vocab_size
        self.mlp_hidden_sizes = mlp_hidden_sizes
        self.mlp_act = mlp_act
        self.mlp_final_act = mlp_final_act
        self.mlp_dropout = mlp_dropout

        self.encoder_config = XLMRobertaConfig(
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            initializer_range=self.initializer_range,
            layer_norm_eps=self.layer_norm_eps,
            use_cache=self.use_cache)
        self.encoder = XLMRobertaModel(
            self.encoder_config, add_pooling_layer=False)
        self.layerwise_attention = LayerwiseAttention(
            num_layers=self.num_hidden_layers + 1,
            model_dim=self.hidden_size,
            dropout=self.mlp_dropout)
        self.estimator = FeedForward(
            in_dim=self.hidden_size,
            out_dim=1,
            hidden_sizes=self.mlp_hidden_sizes,
            activations=self.mlp_act,
            final_activation=self.mlp_final_act,
            dropout=self.mlp_dropout)

        return
    def forward(self, input_sentences: List[torch.Tensor]):
        input_ids = self.combine_input_sentences(input_sentences)
        attention_mask = input_ids.ne(self.pad_token_id).long()
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True)
        mix_states = self.layerwise_attention(outputs['hidden_states'],
                                              attention_mask)
        pred = self.estimator(mix_states)

        return pred.squeeze(dim=-1)

    def load_checkpoint(self, path: str):
        state_dict = torch.load(path)
        self.load_state_dict(state_dict)
        logger.info('Loading checkpoint parameters from %s' % path)
        return

    def combine_input_sentences(self, input_sent_groups: List[torch.Tensor]):
        # Replace the leading <s> of every non-hypothesis group with </s>
        # so that the concatenated segments stay separated.
        for input_sent_group in input_sent_groups[1:]:
            input_sent_group[:, 0] = self.eos_token_id

        if len(input_sent_groups) == 3:
            cutted_sents = self.cut_long_sequences3(input_sent_groups)
        else:
            cutted_sents = self.cut_long_sequences2(input_sent_groups)
        return cutted_sents
    @staticmethod
    def cut_long_sequences2(all_input_concat: List[List[torch.Tensor]],
                            maximum_length: int = 512,
                            pad_idx: int = 1):
        all_input_concat = list(zip(*all_input_concat))
        collected_tuples = list()
        for tensor_tuple in all_input_concat:
            all_lens = tuple(len(x) for x in tensor_tuple)
            if sum(all_lens) > maximum_length:
                lengths = dict(enumerate(all_lens))
                lengths_sorted_idxes = list(x[0] for x in sorted(
                    lengths.items(), key=lambda d: d[1], reverse=True))
                offset = math.ceil(
                    (sum(lengths.values()) - maximum_length) / 2)
                if min(all_lens) > (maximum_length
                                    // 2) and min(all_lens) > offset:
                    lengths = dict((k, v - offset) for k, v in lengths.items())
                else:
                    lengths[lengths_sorted_idxes[
                        0]] = maximum_length - lengths[lengths_sorted_idxes[1]]
                new_lens = list(lengths[k]
                                for k in range(0, len(tensor_tuple)))
                new_tensor_tuple = tuple(
                    x[:y] for x, y in zip(tensor_tuple, new_lens))
                # Preserve the final </s> token of each truncated segment.
                for x, y in zip(new_tensor_tuple, tensor_tuple):
                    x[-1] = y[-1]
                collected_tuples.append(new_tensor_tuple)
            else:
                collected_tuples.append(tensor_tuple)
        concat_tensor = list(torch.cat(x, dim=0) for x in collected_tuples)
        all_input_concat_padded = pad_sequence(
            concat_tensor, batch_first=True, padding_value=pad_idx)
        return all_input_concat_padded
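To make the truncation budget concrete, a worked example (values chosen for illustration): a 400-token hypothesis and a 300-token source exceed the 512 budget by 188 tokens, and because the shorter side is longer than both `512 // 2` and the offset, each side gives up `ceil(188 / 2) = 94` tokens:

```python
import math

lengths = {0: 400, 1: 300}  # hypothesis, source
offset = math.ceil((sum(lengths.values()) - 512) / 2)
lengths = {k: v - offset for k, v in lengths.items()}
print(offset, lengths, sum(lengths.values()))  # 94 {0: 306, 1: 206} 512
```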
    @staticmethod
    def cut_long_sequences3(all_input_concat: List[List[torch.Tensor]],
                            maximum_length: int = 512,
                            pad_idx: int = 1):
        all_input_concat = list(zip(*all_input_concat))
        collected_tuples = list()
        for tensor_tuple in all_input_concat:
            all_lens = tuple(len(x) for x in tensor_tuple)
            if sum(all_lens) > maximum_length:
                lengths = dict(enumerate(all_lens))
                lengths_sorted_idxes = list(x[0] for x in sorted(
                    lengths.items(), key=lambda d: d[1], reverse=True))
                offset = math.ceil(
                    (sum(lengths.values()) - maximum_length) / 3)
                if min(all_lens) > (maximum_length
                                    // 3) and min(all_lens) > offset:
                    lengths = dict((k, v - offset) for k, v in lengths.items())
                else:
                    while sum(lengths.values()) > maximum_length:
                        if lengths[lengths_sorted_idxes[0]] > lengths[
                                lengths_sorted_idxes[1]]:
                            offset = maximum_length - lengths[
                                lengths_sorted_idxes[1]] - lengths[
                                    lengths_sorted_idxes[2]]
                            if offset > lengths[lengths_sorted_idxes[1]]:
                                lengths[lengths_sorted_idxes[0]] = offset
                            else:
                                lengths[lengths_sorted_idxes[0]] = lengths[
                                    lengths_sorted_idxes[1]]
                        elif lengths[lengths_sorted_idxes[0]] == lengths[
                                lengths_sorted_idxes[1]] > lengths[
                                    lengths_sorted_idxes[2]]:
                            offset = (maximum_length
                                      - lengths[lengths_sorted_idxes[2]]) // 2
                            if offset > lengths[lengths_sorted_idxes[2]]:
                                lengths[lengths_sorted_idxes[0]] = lengths[
                                    lengths_sorted_idxes[1]] = offset
                            else:
                                lengths[lengths_sorted_idxes[0]] = lengths[
                                    lengths_sorted_idxes[1]] = lengths[
                                        lengths_sorted_idxes[2]]
                        else:
                            lengths[lengths_sorted_idxes[0]] = lengths[
                                lengths_sorted_idxes[1]] = lengths[
                                    lengths_sorted_idxes[
                                        2]] = maximum_length // 3
                new_lens = list(lengths[k] for k in range(0, len(lengths)))
                new_tensor_tuple = tuple(
                    x[:y] for x, y in zip(tensor_tuple, new_lens))
                # Preserve the final </s> token of each truncated segment.
                for x, y in zip(new_tensor_tuple, tensor_tuple):
                    x[-1] = y[-1]
                collected_tuples.append(new_tensor_tuple)
            else:
                collected_tuples.append(tensor_tuple)
        concat_tensor = list(torch.cat(x, dim=0) for x in collected_tuples)
        all_input_concat_padded = pad_sequence(
            concat_tensor, batch_first=True, padding_value=pad_idx)
        return all_input_concat_padded
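An end-to-end sketch of the model's input contract, using a deliberately tiny, randomly initialized encoder (the sizes below are made up to keep the example fast; real checkpoints use the XLM-R large defaults above). Each entry in the input list is a `[batch, seq_len]` tensor of token ids starting with `<s>` (id 0) and ending with `</s>` (id 2), padded with id 1:

```python
import torch

model = UniTEForTranslationEvaluation(
    hidden_size=64, num_hidden_layers=2, num_attention_heads=2,
    intermediate_size=128, vocab_size=1000, mlp_hidden_sizes=[32, 16])
model.eval()

hyp = torch.tensor([[0, 11, 12, 13, 2]])  # <s> w w w </s>
src = torch.tensor([[0, 21, 22, 2, 1]])   # trailing 1 is padding
with torch.no_grad():
    scores = model([hyp, src])
print(scores.shape)  # torch.Size([1]) -- one quality score per hypothesis
```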
@@ -801,6 +801,11 @@ TASK_OUTPUTS = {
    #     ]
    # }
    Tasks.product_segmentation: [OutputKeys.MASKS],

    # {
    #     'scores': [0.1, 0.2, 0.3, ...]
    # }
    Tasks.translation_evaluation: [OutputKeys.SCORES]
}
@@ -183,6 +183,11 @@ TASK_INPUTS = {
        'query_set': InputType.LIST,
        'support_set': InputType.LIST,
    },
    Tasks.translation_evaluation: {
        'hyp': InputType.LIST,
        'src': InputType.LIST,
        'ref': InputType.LIST,
    },
    # ============ audio tasks ===================
    Tasks.auto_speech_recognition:
@@ -217,6 +217,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
        'damo/cv_swin-t_referring_video-object-segmentation'),
    Tasks.video_summarization: (Pipelines.video_summarization,
                                'damo/cv_googlenet_pgl-video-summarization'),
    Tasks.translation_evaluation:
    (Pipelines.translation_evaluation,
     'damo/nlp_unite_mup_translation_evaluation_multilingual_large'),
}
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
    from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline
    from .codegeex_code_translation_pipeline import CodeGeeXCodeTranslationPipeline
    from .codegeex_code_generation_pipeline import CodeGeeXCodeGenerationPipeline
    from .translation_evaluation_pipeline import TranslationEvaluationPipeline
else:
    _import_structure = {
@@ -77,6 +78,7 @@ else:
        'codegeex_code_translation_pipeline':
        ['CodeGeeXCodeTranslationPipeline'],
        'codegeex_code_generation_pipeline':
        ['CodeGeeXCodeGenerationPipeline'],
        'translation_evaluation_pipeline': ['TranslationEvaluationPipeline'],
    }

    import sys
@@ -0,0 +1,111 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.base import Model
from modelscope.models.nlp.unite.configuration_unite import EvaluationMode
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import InputModel, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import (Preprocessor,
                                      TranslationEvaluationPreprocessor)
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger(__name__)

__all__ = ['TranslationEvaluationPipeline']


@PIPELINES.register_module(
    Tasks.translation_evaluation, module_name=Pipelines.translation_evaluation)
class TranslationEvaluationPipeline(Pipeline):

    def __init__(self,
                 model: InputModel,
                 preprocessor: Optional[Preprocessor] = None,
                 eval_mode: EvaluationMode = EvaluationMode.SRC_REF,
                 **kwargs):
        r"""Build a translation evaluation pipeline from a model dir or a
        model id in the model hub.

        Args:
            model: A Model instance, model dir, or model id.
            preprocessor: An optional Preprocessor instance; if `None`, a
                `TranslationEvaluationPreprocessor` is built from the model dir.
            eval_mode: Evaluation mode, one of `EvaluationMode.SRC_REF`,
                `EvaluationMode.SRC` or `EvaluationMode.REF`. Besides the
                hypothesis, the source, the reference, or both are presented
                to the model during evaluation.
        """
        super().__init__(model=model, preprocessor=preprocessor)

        self.eval_mode = eval_mode
        self.checking_eval_mode()

        self.preprocessor = TranslationEvaluationPreprocessor(
            self.model.model_dir,
            self.eval_mode) if preprocessor is None else preprocessor

        self.model.load_checkpoint(
            osp.join(self.model.model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
        self.model.eval()

        return

    def checking_eval_mode(self):
        if self.eval_mode == EvaluationMode.SRC:
            logger.info('Evaluation mode: source-only')
        elif self.eval_mode == EvaluationMode.REF:
            logger.info('Evaluation mode: reference-only')
        elif self.eval_mode == EvaluationMode.SRC_REF:
            logger.info('Evaluation mode: source-reference-combined')
        else:
            raise ValueError(
                'Evaluation mode should be one choice among '
                '\'EvaluationMode.SRC\', \'EvaluationMode.REF\', and '
                '\'EvaluationMode.SRC_REF\'.')

    def change_eval_mode(self,
                         eval_mode: EvaluationMode = EvaluationMode.SRC_REF):
        logger.info('Changing the evaluation mode.')
        self.eval_mode = eval_mode
        self.checking_eval_mode()
        self.preprocessor.eval_mode = eval_mode
        return

    def __call__(self, input_dict: Dict[str, Union[str, List[str]]], **kwargs):
        r"""Score the hypotheses contained in `input_dict`.

        Args:
            input_dict: A dict containing the input sentences, e.g.:

            ```
            input_dict = {
                'hyp': [
                    'This is a sentence.',
                    'This is another sentence.',
                ],
                'src': [
                    '这是个句子。',
                    '这是另一个句子。',
                ],
                'ref': [
                    'It is a sentence.',
                    'It is another sentence.',
                ]
            }
            ```
        """
        return super().__call__(input=input_dict, **kwargs)

    def forward(self,
                input_ids: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        return self.model(input_ids)

    def postprocess(self, output: torch.Tensor) -> Dict[str, Any]:
        result = {OutputKeys.SCORES: output.cpu().tolist()}
        return result
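A quick usage sketch (it mirrors the unit tests at the end of this review; the model is downloaded from the hub on first use):

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(
    Tasks.translation_evaluation,
    model='damo/nlp_unite_mup_translation_evaluation_multilingual_large')
print(pipe({
    'hyp': ['This is a sentence.'],
    'src': ['这是个句子。'],
    'ref': ['It is a sentence.'],
}))  # -> {'scores': [...]}
```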
@@ -33,7 +33,8 @@ if TYPE_CHECKING:
        DialogIntentPredictionPreprocessor, DialogModelingPreprocessor,
        DialogStateTrackingPreprocessor, ConversationalTextToSqlPreprocessor,
        TableQuestionAnsweringPreprocessor, NERPreprocessorViet,
        NERPreprocessorThai, WordSegmentationPreprocessorThai,
        TranslationEvaluationPreprocessor)
    from .video import ReadVideoData, MovieSceneSegmentationPreprocessor
else:
@@ -72,7 +73,8 @@ else:
            'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
            'DialogStateTrackingPreprocessor',
            'ConversationalTextToSqlPreprocessor',
            'TableQuestionAnsweringPreprocessor',
            'TranslationEvaluationPreprocessor'
        ],
    }
@@ -28,6 +28,7 @@ if TYPE_CHECKING:
    from .space_T_en import ConversationalTextToSqlPreprocessor
    from .space_T_cn import TableQuestionAnsweringPreprocessor
    from .mglm_summarization_preprocessor import MGLMSummarizationPreprocessor
    from .translation_evaluation_preprocessor import TranslationEvaluationPreprocessor
else:
    _import_structure = {
        'sentence_piece_preprocessor': ['SentencePiecePreprocessor'],
@@ -76,6 +77,8 @@ else:
        ],
        'space_T_en': ['ConversationalTextToSqlPreprocessor'],
        'space_T_cn': ['TableQuestionAnsweringPreprocessor'],
        'translation_evaluation_preprocessor':
        ['TranslationEvaluationPreprocessor'],
    }

    import sys
@@ -0,0 +1,87 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, List, Union

import torch
from transformers import AutoTokenizer

from modelscope.metainfo import Preprocessors
from modelscope.models.nlp.unite.configuration_unite import EvaluationMode
from modelscope.preprocessors import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.constant import Fields, ModeKeys
from .transformers_tokenizer import NLPTokenizer


@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.translation_evaluation)
class TranslationEvaluationPreprocessor(Preprocessor):
    r"""The tokenizer preprocessor used for translation evaluation."""

    def __init__(self,
                 model_dir: str,
                 eval_mode: EvaluationMode,
                 mode=ModeKeys.INFERENCE,
                 *args,
                 **kwargs):
        r"""Preprocess the data using the vocabulary files found under `model_dir`.

        Args:
            model_dir: The directory containing the model and tokenizer files.
            eval_mode: Evaluation mode, one of `EvaluationMode.SRC_REF`,
                `EvaluationMode.SRC` or `EvaluationMode.REF`. Besides the
                hypothesis, the source, the reference, or both are presented
                during evaluation.
        """
        super().__init__(mode=mode)
        self.tokenizer = NLPTokenizer(
            model_dir=model_dir, use_fast=False, tokenize_kwargs=kwargs)
        self.eval_mode = eval_mode
        return

    def __call__(self, input_dict: Dict[str, Any]) -> List[torch.Tensor]:
        if self.eval_mode == EvaluationMode.SRC and 'src' not in input_dict:
            raise ValueError(
                'Source sentences are required for source-only evaluation mode.'
            )
        if self.eval_mode == EvaluationMode.REF and 'ref' not in input_dict:
            raise ValueError(
                'Reference sentences are required for reference-only evaluation mode.'
            )
        if self.eval_mode == EvaluationMode.SRC_REF and (
                'src' not in input_dict or 'ref' not in input_dict):
            raise ValueError(
                'Source and reference sentences are both required for '
                'source-reference-combined evaluation mode.')

        # Allow plain strings as a convenience; wrap them into batches of 1.
        if isinstance(input_dict['hyp'], str):
            input_dict['hyp'] = [input_dict['hyp']]
        if self.eval_mode in (EvaluationMode.SRC, EvaluationMode.SRC_REF) \
                and isinstance(input_dict['src'], str):
            input_dict['src'] = [input_dict['src']]
        if self.eval_mode in (EvaluationMode.REF, EvaluationMode.SRC_REF) \
                and isinstance(input_dict['ref'], str):
            input_dict['ref'] = [input_dict['ref']]

        output_sents = [
            self.tokenizer(
                input_dict['hyp'], return_tensors='pt',
                padding=True)['input_ids']
        ]
        if self.eval_mode in (EvaluationMode.SRC, EvaluationMode.SRC_REF):
            output_sents += [
                self.tokenizer(
                    input_dict['src'], return_tensors='pt',
                    padding=True)['input_ids']
            ]
        if self.eval_mode in (EvaluationMode.REF, EvaluationMode.SRC_REF):
            output_sents += [
                self.tokenizer(
                    input_dict['ref'], return_tensors='pt',
                    padding=True)['input_ids']
            ]

        return output_sents
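Illustrative check of the preprocessor's output contract (the model path below is a placeholder): in `SRC_REF` mode it returns three padded id tensors, ordered hypothesis, source, reference:

```python
prep = TranslationEvaluationPreprocessor(
    model_dir='/path/to/unite/model',  # hypothetical local model dir
    eval_mode=EvaluationMode.SRC_REF)

batch = prep({
    'hyp': ['This is a sentence.'],
    'src': ['这是个句子。'],
    'ref': ['It is a sentence.'],
})
print(len(batch), batch[0].shape)  # 3, torch.Size([1, seq_len])
```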
@@ -133,6 +133,7 @@ class NLPTasks(object):
    document_segmentation = 'document-segmentation'
    extractive_summarization = 'extractive-summarization'
    feature_extraction = 'feature-extraction'
    translation_evaluation = 'translation-evaluation'


class AudioTasks(object):
@@ -0,0 +1,73 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.models.nlp.unite.configuration_unite import EvaluationMode
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class TranslationEvaluationTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.task = Tasks.translation_evaluation
        self.model_id_large = 'damo/nlp_unite_mup_translation_evaluation_multilingual_large'
        self.model_id_base = 'damo/nlp_unite_mup_translation_evaluation_multilingual_base'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name_for_unite_large(self):
        input_dict = {
            'hyp': [
                'This is a sentence.',
                'This is another sentence.',
            ],
            'src': [
                '这是个句子。',
                '这是另一个句子。',
            ],
            'ref': [
                'It is a sentence.',
                'It is another sentence.',
            ]
        }

        pipeline_ins = pipeline(self.task, model=self.model_id_large)
        print(pipeline_ins(input_dict))

        pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.SRC)
        print(pipeline_ins(input_dict))

        pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.REF)
        print(pipeline_ins(input_dict))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name_for_unite_base(self):
        input_dict = {
            'hyp': [
                'This is a sentence.',
                'This is another sentence.',
            ],
            'src': [
                '这是个句子。',
                '这是另一个句子。',
            ],
            'ref': [
                'It is a sentence.',
                'It is another sentence.',
            ]
        }

        pipeline_ins = pipeline(self.task, model=self.model_id_base)
        print(pipeline_ins(input_dict))

        pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.SRC)
        print(pipeline_ins(input_dict))

        pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.REF)
        print(pipeline_ins(input_dict))


if __name__ == '__main__':
    unittest.main()