
[to #42322933] Add backbone-head model structure

Branch: master
Author: zhangzhicheng.zzc, 3 years ago
Commit: 68fc437044
58 changed files with 2150 additions and 129 deletions
  1. configs/nlp/sbert_sentence_similarity.json (+28, -19)
  2. configs/nlp/sequence_classification_trainer.yaml (+3, -0)
  3. modelscope/metainfo.py (+10, -0)
  4. modelscope/metrics/builder.py (+1, -0)
  5. modelscope/models/__init__.py (+2, -0)
  6. modelscope/models/base/__init__.py (+4, -0)
  7. modelscope/models/base/base_head.py (+48, -0)
  8. modelscope/models/base/base_model.py (+23, -1)
  9. modelscope/models/base/base_torch_head.py (+30, -0)
  10. modelscope/models/base/base_torch_model.py (+55, -0)
  11. modelscope/models/base_torch.py (+0, -23)
  12. modelscope/models/builder.py (+29, -1)
  13. modelscope/models/nlp/__init__.py (+7, -4)
  14. modelscope/models/nlp/backbones/__init__.py (+4, -0)
  15. modelscope/models/nlp/backbones/space/__init__.py (+2, -0)
  16. modelscope/models/nlp/backbones/space/model/__init__.py (+0, -0)
  17. modelscope/models/nlp/backbones/space/model/gen_unified_transformer.py (+0, -0)
  18. modelscope/models/nlp/backbones/space/model/generator.py (+0, -0)
  19. modelscope/models/nlp/backbones/space/model/intent_unified_transformer.py (+1, -1)
  20. modelscope/models/nlp/backbones/space/model/model_base.py (+1, -1)
  21. modelscope/models/nlp/backbones/space/model/unified_transformer.py (+0, -0)
  22. modelscope/models/nlp/backbones/space/modules/__init__.py (+0, -0)
  23. modelscope/models/nlp/backbones/space/modules/embedder.py (+0, -0)
  24. modelscope/models/nlp/backbones/space/modules/feedforward.py (+0, -0)
  25. modelscope/models/nlp/backbones/space/modules/functions.py (+0, -0)
  26. modelscope/models/nlp/backbones/space/modules/multihead_attention.py (+0, -0)
  27. modelscope/models/nlp/backbones/space/modules/transformer_block.py (+0, -0)
  28. modelscope/models/nlp/backbones/structbert/__init__.py (+1, -0)
  29. modelscope/models/nlp/backbones/structbert/adv_utils.py (+166, -0)
  30. modelscope/models/nlp/backbones/structbert/configuration_sbert.py (+131, -0)
  31. modelscope/models/nlp/backbones/structbert/modeling_sbert.py (+815, -0)
  32. modelscope/models/nlp/heads/__init__.py (+3, -0)
  33. modelscope/models/nlp/heads/sequence_classification_head.py (+44, -0)
  34. modelscope/models/nlp/palm_for_text_generation.py (+1, -2)
  35. modelscope/models/nlp/sbert_for_sequence_classification.py (+3, -0)
  36. modelscope/models/nlp/sequence_classification.py (+85, -0)
  37. modelscope/models/nlp/space/modules/__init__.py (+0, -0)
  38. modelscope/models/nlp/space_for_dialog_intent_prediction.py (+10, -10)
  39. modelscope/models/nlp/space_for_dialog_modeling.py (+10, -10)
  40. modelscope/models/nlp/space_for_dialog_state_tracking.py (+4, -4)
  41. modelscope/models/nlp/task_model.py (+489, -0)
  42. modelscope/outputs.py (+3, -0)
  43. modelscope/pipelines/builder.py (+2, -1)
  44. modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py (+1, -1)
  45. modelscope/pipelines/nlp/dialog_modeling_pipeline.py (+1, -1)
  46. modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py (+1, -1)
  47. modelscope/pipelines/nlp/sentiment_classification_pipeline.py (+6, -6)
  48. modelscope/preprocessors/base.py (+11, -0)
  49. modelscope/preprocessors/nlp.py (+8, -12)
  50. modelscope/trainers/trainer.py (+14, -7)
  51. modelscope/trainers/utils/inference.py (+9, -2)
  52. modelscope/utils/constant.py (+2, -0)
  53. modelscope/utils/registry.py (+10, -5)
  54. modelscope/utils/tensor_utils.py (+10, -6)
  55. modelscope/utils/utils.py (+28, -0)
  56. tests/models/test_base_torch.py (+1, -1)
  57. tests/pipelines/test_sentiment_classification.py (+16, -10)
  58. tests/trainers/test_trainer_with_nlp.py (+17, -0)

configs/nlp/sbert_sentence_similarity.json (+28, -19)

@@ -6,25 +6,34 @@
         "second_sequence": "sentence2"
     },
     "model": {
-        "type": "structbert",
-        "attention_probs_dropout_prob": 0.1,
-        "easynlp_version": "0.0.3",
-        "gradient_checkpointing": false,
-        "hidden_act": "gelu",
-        "hidden_dropout_prob": 0.1,
-        "hidden_size": 768,
-        "initializer_range": 0.02,
-        "intermediate_size": 3072,
-        "layer_norm_eps": 1e-12,
-        "max_position_embeddings": 512,
-        "num_attention_heads": 12,
-        "num_hidden_layers": 12,
-        "pad_token_id": 0,
-        "position_embedding_type": "absolute",
-        "transformers_version": "4.6.0.dev0",
-        "type_vocab_size": 2,
-        "use_cache": true,
-        "vocab_size": 30522
+        "type": "text-classification",
+        "backbone": {
+            "type": "structbert",
+            "prefix": "encoder",
+            "attention_probs_dropout_prob": 0.1,
+            "easynlp_version": "0.0.3",
+            "gradient_checkpointing": false,
+            "hidden_act": "gelu",
+            "hidden_dropout_prob": 0.1,
+            "hidden_size": 768,
+            "initializer_range": 0.02,
+            "intermediate_size": 3072,
+            "layer_norm_eps": 1e-12,
+            "max_position_embeddings": 512,
+            "num_attention_heads": 12,
+            "num_hidden_layers": 12,
+            "pad_token_id": 0,
+            "position_embedding_type": "absolute",
+            "transformers_version": "4.6.0.dev0",
+            "type_vocab_size": 2,
+            "use_cache": true,
+            "vocab_size": 21128
+        },
+        "head": {
+            "type": "text-classification",
+            "hidden_dropout_prob": 0.1,
+            "hidden_size": 768
+        }
     },
     "pipeline": {
         "type": "sentence-similarity"


configs/nlp/sequence_classification_trainer.yaml (+3, -0)

@@ -6,6 +6,9 @@ task: text-classification

 model:
   path: bert-base-sst2
+  backbone:
+    type: bert
+    prefix: bert
   attention_probs_dropout_prob: 0.1
   bos_token_id: 0
   eos_token_id: 2


modelscope/metainfo.py (+10, -0)

@@ -33,6 +33,16 @@ class Models(object):
     imagen = 'imagen-text-to-image-synthesis'


+class TaskModels(object):
+    # nlp task
+    text_classification = 'text-classification'
+
+
+class Heads(object):
+    # nlp heads
+    text_classification = 'text-classification'
+
+
 class Pipelines(object):
     """ Names for different pipelines.




modelscope/metrics/builder.py (+1, -0)

@@ -17,6 +17,7 @@ class MetricKeys(object):

 task_default_metrics = {
     Tasks.sentence_similarity: [Metrics.seq_cls_metric],
+    Tasks.sentiment_classification: [Metrics.seq_cls_metric],
     Tasks.text_generation: [Metrics.text_gen_metric],
 }




modelscope/models/__init__.py (+2, -0)

@@ -29,6 +29,8 @@ try:
         SbertForZeroShotClassification, SpaceForDialogIntent,
         SpaceForDialogModeling, SpaceForDialogStateTracking,
         StructBertForMaskedLM, VecoForMaskedLM)
+    from .nlp.heads import (SequenceClassificationHead)
+    from .nlp.backbones import (SbertModel)
 except ModuleNotFoundError as e:
     if str(e) == "No module named 'pytorch'":
         pass


modelscope/models/base/__init__.py (+4, -0)

@@ -0,0 +1,4 @@
from .base_head import * # noqa F403
from .base_model import * # noqa F403
from .base_torch_head import * # noqa F403
from .base_torch_model import * # noqa F403

modelscope/models/base/base_head.py (+48, -0)

@@ -0,0 +1,48 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from abc import ABC, abstractmethod
from typing import Dict, List, Union

import numpy as np

from ...utils.config import ConfigDict
from ...utils.logger import get_logger
from .base_model import Model

logger = get_logger()

Tensor = Union['torch.Tensor', 'tf.Tensor']
Input = Union[Dict[str, Tensor], Model]


class Head(ABC):
"""
    Base class that defines the interface of task heads.

"""

def __init__(self, **kwargs):
self.config = ConfigDict(kwargs)

@abstractmethod
def forward(self, input: Input) -> Dict[str, Tensor]:
"""
        Use the output of the backbone model to perform the downstream task.

        Args:
            input: The tensor output of the backbone model, or the backbone model
                itself (text generation needs the model as input).

        Returns: The output of the downstream task.
"""
pass

@abstractmethod
def compute_loss(self, outputs: Dict[str, Tensor],
labels) -> Dict[str, Tensor]:
"""
        Compute the loss for this head during fine-tuning.

        Args:
            outputs (Dict[str, Tensor]): the output of the model forward pass
            labels: the ground-truth labels

        Returns: the loss (Dict[str, Tensor])
"""
pass

modelscope/models/base.py → modelscope/models/base/base_model.py

@@ -4,6 +4,8 @@ import os.path as osp
 from abc import ABC, abstractmethod
 from typing import Dict, Optional, Union

+import numpy as np
+
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models.builder import build_model
 from modelscope.utils.config import Config
@@ -25,6 +27,15 @@ class Model(ABC):

     @abstractmethod
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        """
+        Run the forward pass for a model.
+
+        Args:
+            input (Dict[str, Tensor]): the dict of the model inputs for the forward method
+
+        Returns:
+            Dict[str, Tensor]: output from the model forward pass
+        """
         pass

     def postprocess(self, input: Dict[str, Tensor],
@@ -41,6 +52,15 @@
         """
         return input

+    @classmethod
+    def _instantiate(cls, **kwargs):
+        """ Define the instantiation method of a model; the default is to
+        call the constructor. If a task model does not load its weights in
+        the constructor, a load_model method is added and this method is
+        overridden accordingly.
+        """
+        return cls(**kwargs)
+
     @classmethod
     def from_pretrained(cls,
                         model_name_or_path: str,
@@ -71,6 +91,7 @@
             cfg, 'pipeline'), 'pipeline config is missing from config file.'
         pipeline_cfg = cfg.pipeline
         # TODO @wenmeng.zwm may should manually initialize model after model building
+
         if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
             model_cfg.type = model_cfg.model_type

@@ -78,7 +99,8 @@

         for k, v in kwargs.items():
             model_cfg[k] = v
-        model = build_model(model_cfg, task_name)
+        model = build_model(
+            model_cfg, task_name=task_name, default_args=kwargs)

         # dynamically add pipeline info to model for pipeline inference
         model.pipeline = pipeline_cfg
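
For reference, a hedged usage sketch of the updated loading path: from_pretrained resolves a local directory or downloads a hub snapshot, reads the configuration, and now forwards extra kwargs to build_model as default_args. The model id and the num_labels kwarg below are placeholders, not taken from this commit.

from modelscope.models import Model

# placeholder model id; extra kwargs are merged into the model config
# and also forwarded to build_model as default_args
model = Model.from_pretrained('some-org/some-structbert-model', num_labels=3)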

modelscope/models/base/base_torch_head.py (+30, -0)

@@ -0,0 +1,30 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path
import re
from typing import Dict, Optional, Union

import torch
from torch import nn

from ...utils.logger import get_logger
from .base_head import Head

logger = get_logger(__name__)


class TorchHead(Head, torch.nn.Module):
""" Base head interface for pytorch

"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
torch.nn.Module.__init__(self)

def forward(self, inputs: Dict[str,
torch.Tensor]) -> Dict[str, torch.Tensor]:
raise NotImplementedError

def compute_loss(self, outputs: Dict[str, torch.Tensor],
labels) -> Dict[str, torch.Tensor]:
raise NotImplementedError
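
The concrete head added by this commit is SequenceClassificationHead (modelscope/models/nlp/heads/sequence_classification_head.py, not shown above). As a rough illustration of the TorchHead interface only, a hypothetical head could look like the sketch below; the registry group key and module name are made up.

# Hypothetical sketch of the TorchHead interface; not the commit's SequenceClassificationHead.
from typing import Dict

import torch
from torch import nn

from modelscope.models.base import TorchHead
from modelscope.models.builder import HEADS


@HEADS.register_module('text-classification', module_name='toy-cls-head')
class ToyClassificationHead(TorchHead):

    def __init__(self, hidden_size=768, num_labels=2, hidden_dropout_prob=0.1, **kwargs):
        super().__init__(
            hidden_size=hidden_size,
            num_labels=num_labels,
            hidden_dropout_prob=hidden_dropout_prob,
            **kwargs)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # expects the pooled backbone output, e.g. SbertModel's pooler_output
        logits = self.classifier(self.dropout(inputs['pooler_output']))
        return {'logits': logits}

    def compute_loss(self, outputs: Dict[str, torch.Tensor], labels) -> Dict[str, torch.Tensor]:
        return {'loss': nn.CrossEntropyLoss()(outputs['logits'], labels)}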

modelscope/models/base/base_torch_model.py (+55, -0)

@@ -0,0 +1,55 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict, Optional, Union

import torch
from torch import nn

from ...utils.logger import get_logger
from .base_model import Model

logger = get_logger(__name__)


class TorchModel(Model, torch.nn.Module):
""" Base model interface for pytorch

"""

def __init__(self, model_dir=None, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
torch.nn.Module.__init__(self)

def forward(self, inputs: Dict[str,
torch.Tensor]) -> Dict[str, torch.Tensor]:
raise NotImplementedError

def post_init(self):
"""
A method executed at the end of each model initialization, to execute code that needs the model's
modules properly initialized (such as weight initialization).
"""
self.init_weights()

def init_weights(self):
# Initialize weights
self.apply(self._init_weights)

def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def compute_loss(self, outputs: Dict[str, Any], labels):
raise NotImplementedError()
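
A hedged sketch of how a subclass is expected to use this base: build the submodules in the constructor, then call post_init() so that _init_weights runs over them. The class below is hypothetical and only illustrates the contract.

from typing import Dict

import torch
from torch import nn

from modelscope.models.base import TorchModel


class ToyRegressor(TorchModel):
    """Hypothetical example, not part of the commit."""

    def __init__(self, model_dir=None, hidden_size=8):
        super().__init__(model_dir)
        self.linear = nn.Linear(hidden_size, 1)
        self.post_init()  # runs init_weights() -> _init_weights on every submodule

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return {'predictions': self.linear(inputs['features'])}


model = ToyRegressor(hidden_size=8)
outputs = model(dict(features=torch.randn(2, 8)))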

modelscope/models/base_torch.py (+0, -23)

@@ -1,23 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Dict

import torch

from .base import Model


class TorchModel(Model, torch.nn.Module):
""" Base model interface for pytorch

"""

def __init__(self, model_dir=None, *args, **kwargs):
# init reference: https://stackoverflow.com/questions\
# /9575409/calling-parent-class-init-with-multiple-inheritance-whats-the-right-way
super().__init__(model_dir)
super(Model, self).__init__()

def forward(self, inputs: Dict[str,
torch.Tensor]) -> Dict[str, torch.Tensor]:
raise NotImplementedError

modelscope/models/builder.py (+29, -1)

@@ -1,9 +1,11 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

 from modelscope.utils.config import ConfigDict
-from modelscope.utils.registry import Registry, build_from_cfg
+from modelscope.utils.registry import TYPE_NAME, Registry, build_from_cfg

 MODELS = Registry('models')
+BACKBONES = Registry('backbones')
+HEADS = Registry('heads')


 def build_model(cfg: ConfigDict,
@@ -19,3 +21,29 @@ def build_model(cfg: ConfigDict,
     """
     return build_from_cfg(
         cfg, MODELS, group_key=task_name, default_args=default_args)


def build_backbone(cfg: ConfigDict,
field: str = None,
default_args: dict = None):
""" build backbone given backbone config dict

Args:
cfg (:obj:`ConfigDict`): config dict for backbone object.
field (str, optional): field, such as CV, NLP's backbone
default_args (dict, optional): Default initialization arguments.
"""
return build_from_cfg(
cfg, BACKBONES, group_key=field, default_args=default_args)


def build_head(cfg: ConfigDict, default_args: dict = None):
""" build head given config dict

Args:
cfg (:obj:`ConfigDict`): config dict for head object.
default_args (dict, optional): Default initialization arguments.
"""

return build_from_cfg(
cfg, HEADS, group_key=cfg[TYPE_NAME], default_args=default_args)
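
The three registries split model assembly by role: MODELS is grouped by task name, BACKBONES by field (e.g. Fields.nlp), and HEADS by the head's own type string (cfg[TYPE_NAME]). Below is a small hedged sketch of registering and building a backbone, assuming build_from_cfg accepts a plain dict with a 'type' key as in mmcv-style registries; the 'toy-encoder' name is made up.

from modelscope.models.builder import BACKBONES, build_backbone
from modelscope.utils.constant import Fields


@BACKBONES.register_module(Fields.nlp, module_name='toy-encoder')
class ToyEncoder:
    """Made-up backbone used only to illustrate the registry flow."""

    def __init__(self, hidden_size=768, **kwargs):
        self.hidden_size = hidden_size


encoder = build_backbone(dict(type='toy-encoder', hidden_size=256), field=Fields.nlp)
assert encoder.hidden_size == 256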

modelscope/models/nlp/__init__.py (+7, -4)

@@ -1,6 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from modelscope.utils.error import TENSORFLOW_IMPORT_WARNING
+from ...utils.error import TENSORFLOW_IMPORT_WARNING
+from .backbones import * # noqa F403
 from .bert_for_sequence_classification import * # noqa F403
+from .heads import * # noqa F403
 from .masked_language import * # noqa F403
 from .nncrf_for_named_entity_recognition import * # noqa F403
 from .palm_for_text_generation import * # noqa F403
@@ -9,9 +11,10 @@ from .sbert_for_sentence_similarity import * # noqa F403
 from .sbert_for_sentiment_classification import * # noqa F403
 from .sbert_for_token_classification import * # noqa F403
 from .sbert_for_zero_shot_classification import * # noqa F403
-from .space.dialog_intent_prediction_model import * # noqa F403
-from .space.dialog_modeling_model import * # noqa F403
-from .space.dialog_state_tracking_model import * # noqa F403
+from .sequence_classification import * # noqa F403
+from .space_for_dialog_intent_prediction import * # noqa F403
+from .space_for_dialog_modeling import * # noqa F403
+from .space_for_dialog_state_tracking import * # noqa F403

 try:
     from .csanmt_for_translation import CsanmtForTranslation


modelscope/models/nlp/backbones/__init__.py (+4, -0)

@@ -0,0 +1,4 @@
from .space import SpaceGenerator, SpaceModelBase
from .structbert import SbertModel

__all__ = ['SbertModel', 'SpaceGenerator', 'SpaceModelBase']

modelscope/models/nlp/backbones/space/__init__.py (+2, -0)

@@ -0,0 +1,2 @@
from .model.generator import Generator as SpaceGenerator
from .model.model_base import SpaceModelBase

modelscope/models/nlp/space/model/__init__.py → modelscope/models/nlp/backbones/space/model/__init__.py


modelscope/models/nlp/space/model/gen_unified_transformer.py → modelscope/models/nlp/backbones/space/model/gen_unified_transformer.py


modelscope/models/nlp/space/model/generator.py → modelscope/models/nlp/backbones/space/model/generator.py


modelscope/models/nlp/space/model/intent_unified_transformer.py → modelscope/models/nlp/backbones/space/model/intent_unified_transformer.py

@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from .....utils.nlp.space.criterions import compute_kl_loss
+from ......utils.nlp.space.criterions import compute_kl_loss
 from .unified_transformer import UnifiedTransformer





modelscope/models/nlp/space/model/model_base.py → modelscope/models/nlp/backbones/space/model/model_base.py

@@ -4,7 +4,7 @@ import os

 import torch.nn as nn

-from .....utils.constant import ModelFile
+from ......utils.constant import ModelFile


 class SpaceModelBase(nn.Module):

modelscope/models/nlp/space/model/unified_transformer.py → modelscope/models/nlp/backbones/space/model/unified_transformer.py


modelscope/models/nlp/space/__init__.py → modelscope/models/nlp/backbones/space/modules/__init__.py


modelscope/models/nlp/space/modules/embedder.py → modelscope/models/nlp/backbones/space/modules/embedder.py


modelscope/models/nlp/space/modules/feedforward.py → modelscope/models/nlp/backbones/space/modules/feedforward.py


modelscope/models/nlp/space/modules/functions.py → modelscope/models/nlp/backbones/space/modules/functions.py


modelscope/models/nlp/space/modules/multihead_attention.py → modelscope/models/nlp/backbones/space/modules/multihead_attention.py


modelscope/models/nlp/space/modules/transformer_block.py → modelscope/models/nlp/backbones/space/modules/transformer_block.py


modelscope/models/nlp/backbones/structbert/__init__.py (+1, -0)

@@ -0,0 +1 @@
from .modeling_sbert import SbertModel

modelscope/models/nlp/backbones/structbert/adv_utils.py (+166, -0)

@@ -0,0 +1,166 @@
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn

from .....utils.logger import get_logger

logger = get_logger(__name__)


def _symmetric_kl_div(logits1, logits2, attention_mask=None):
"""
    Calculate the symmetric KL divergence between two logits.
:param logits1: The first logit.
:param logits2: The second logit.
:param attention_mask: An optional attention_mask which is used to mask some element out.
This is usually useful in token_classification tasks.
If the shape of logits is [N1, N2, ... Nn, D], the shape of attention_mask should be [N1, N2, ... Nn]
:return: The mean loss.
"""
labels_num = logits1.shape[-1]
KLDiv = nn.KLDivLoss(reduction='none')
loss = torch.sum(
KLDiv(nn.LogSoftmax(dim=-1)(logits1),
nn.Softmax(dim=-1)(logits2)),
dim=-1) + torch.sum(
KLDiv(nn.LogSoftmax(dim=-1)(logits2),
nn.Softmax(dim=-1)(logits1)),
dim=-1)
if attention_mask is not None:
loss = torch.sum(
loss * attention_mask) / torch.sum(attention_mask) / labels_num
else:
loss = torch.mean(loss) / labels_num
return loss


def compute_adv_loss(embedding,
model,
ori_logits,
ori_loss,
adv_grad_factor,
adv_bound=None,
sigma=5e-6,
**kwargs):
"""
Calculate the adv loss of the model.
    :param embedding: Original sentence embedding
    :param model: The model or the forward function (including decoder/classifier); accepts kwargs as input and outputs logits
    :param ori_logits: The original logits output by the model function
    :param ori_loss: The original loss
    :param adv_grad_factor: This factor is multiplied by the KL loss grad and the result is added to
    the original embedding.
    For more details please check https://arxiv.org/abs/1908.04577
    This value usually lies in the range 1e-3~1e-7
    :param adv_bound: adv_bound is used to clip the top and bottom bounds of the produced embedding.
    If not provided, 2 * sigma will be used as the adv_bound factor
    :param sigma: The std used to produce a zero-mean normal distribution.
    If adv_bound is not provided, 2 * sigma will be used as the adv_bound factor
    :param kwargs: the input params used in the model function
    :return: The original loss plus the adv loss
"""
adv_bound = adv_bound if adv_bound is not None else 2 * sigma
embedding_1 = embedding + embedding.data.new(embedding.size()).normal_(
0, sigma) # 95% in +- 1e-5
kwargs.pop('input_ids')
if 'inputs_embeds' in kwargs:
kwargs.pop('inputs_embeds')
with_attention_mask = False if 'with_attention_mask' not in kwargs else kwargs[
'with_attention_mask']
attention_mask = kwargs['attention_mask']
if not with_attention_mask:
attention_mask = None
if 'with_attention_mask' in kwargs:
kwargs.pop('with_attention_mask')
outputs = model(**kwargs, inputs_embeds=embedding_1)
v1_logits = outputs.logits
loss = _symmetric_kl_div(ori_logits, v1_logits, attention_mask)
emb_grad = torch.autograd.grad(loss, embedding_1)[0].data
emb_grad_norm = emb_grad.norm(
dim=2, keepdim=True, p=float('inf')).max(
1, keepdim=True)[0]
is_nan = torch.any(torch.isnan(emb_grad_norm))
if is_nan:
        logger.warning('NaN occurred when calculating adv loss.')
return ori_loss
emb_grad = emb_grad / emb_grad_norm
embedding_2 = embedding_1 + adv_grad_factor * emb_grad
embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2)
embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2)
outputs = model(**kwargs, inputs_embeds=embedding_2)
adv_logits = outputs.logits
adv_loss = _symmetric_kl_div(ori_logits, adv_logits, attention_mask)
return ori_loss + adv_loss


def compute_adv_loss_pair(embedding,
model,
start_logits,
end_logits,
ori_loss,
adv_grad_factor,
adv_bound=None,
sigma=5e-6,
**kwargs):
"""
    Calculate the adv loss of the model. This function is used in the paired-logits scenario.
    :param embedding: Original sentence embedding
    :param model: The model or the forward function (including decoder/classifier); accepts kwargs as input and outputs logits
    :param start_logits: The original start logits output by the model function
    :param end_logits: The original end logits output by the model function
    :param ori_loss: The original loss
    :param adv_grad_factor: This factor is multiplied by the KL loss grad and the result is added to
    the original embedding.
    For more details please check https://arxiv.org/abs/1908.04577
    This value usually lies in the range 1e-3~1e-7
    :param adv_bound: adv_bound is used to clip the top and bottom bounds of the produced embedding.
    If not provided, 2 * sigma will be used as the adv_bound factor
    :param sigma: The std used to produce a zero-mean normal distribution.
    If adv_bound is not provided, 2 * sigma will be used as the adv_bound factor
    :param kwargs: the input params used in the model function
    :return: The original loss plus the adv loss
"""
adv_bound = adv_bound if adv_bound is not None else 2 * sigma
embedding_1 = embedding + embedding.data.new(embedding.size()).normal_(
0, sigma) # 95% in +- 1e-5
kwargs.pop('input_ids')
if 'inputs_embeds' in kwargs:
kwargs.pop('inputs_embeds')
outputs = model(**kwargs, inputs_embeds=embedding_1)
v1_logits_start, v1_logits_end = outputs.logits
loss = _symmetric_kl_div(start_logits,
v1_logits_start) + _symmetric_kl_div(
end_logits, v1_logits_end)
loss = loss / 2
emb_grad = torch.autograd.grad(loss, embedding_1)[0].data
emb_grad_norm = emb_grad.norm(
dim=2, keepdim=True, p=float('inf')).max(
1, keepdim=True)[0]
is_nan = torch.any(torch.isnan(emb_grad_norm))
if is_nan:
        logger.warning('NaN occurred when calculating pair adv loss.')
return ori_loss
emb_grad = emb_grad / emb_grad_norm
embedding_2 = embedding_1 + adv_grad_factor * emb_grad
embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2)
embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2)
outputs = model(**kwargs, inputs_embeds=embedding_2)
adv_logits_start, adv_logits_end = outputs.logits
adv_loss = _symmetric_kl_div(start_logits,
adv_logits_start) + _symmetric_kl_div(
end_logits, adv_logits_end)
return ori_loss + adv_loss
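
To make the calling contract concrete, here is a self-contained, hedged toy: the classifier below is hypothetical (the real consumer is a task model whose forward accepts inputs_embeds and returns an object with a logits attribute); it only exercises compute_adv_loss mechanically.

import types

import torch
from torch import nn

from modelscope.models.nlp.backbones.structbert.adv_utils import compute_adv_loss


class ToyClassifier(nn.Module):
    """Hypothetical stand-in for a task model; not part of the commit."""

    def __init__(self, vocab=100, dim=16, num_labels=2):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.cls = nn.Linear(dim, num_labels)

    def forward(self, input_ids=None, inputs_embeds=None, attention_mask=None):
        x = self.emb(input_ids) if inputs_embeds is None else inputs_embeds
        return types.SimpleNamespace(logits=self.cls(x.mean(dim=1)))


model = ToyClassifier()
input_ids = torch.randint(0, 100, (4, 8))
attention_mask = torch.ones(4, 8)
embeds = model.emb(input_ids)  # original embeddings, differentiable through emb.weight
out = model(inputs_embeds=embeds, attention_mask=attention_mask)
ori_loss = nn.CrossEntropyLoss()(out.logits, torch.zeros(4, dtype=torch.long))
total_loss = compute_adv_loss(
    embeds, model, out.logits, ori_loss,
    adv_grad_factor=1e-4,
    input_ids=input_ids,
    attention_mask=attention_mask)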

modelscope/models/nlp/backbones/structbert/configuration_sbert.py (+131, -0)

@@ -0,0 +1,131 @@
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" SBERT model configuration, mainly copied from :class:`~transformers.BertConfig` """
from transformers import PretrainedConfig

from .....utils import logger as logging

logger = logging.get_logger(__name__)


class SbertConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~sofa.models.SbertModel`.
It is used to instantiate a SBERT model according to the specified arguments.

Configuration objects inherit from :class:`~sofa.utils.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~sofa.utils.PretrainedConfig` for more information.


Args:
vocab_size (:obj:`int`, `optional`, defaults to 30522):
Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
:obj:`inputs_ids` passed when calling :class:`~transformers.BertModel` or
:class:`~transformers.TFBertModel`.
hidden_size (:obj:`int`, `optional`, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (:obj:`int`, `optional`, defaults to 2):
The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BertModel` or
:class:`~transformers.TFBertModel`.
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
:obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
<https://arxiv.org/abs/2009.13658>`__.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if ``config.is_decoder=True``.
classifier_dropout (:obj:`float`, `optional`):
The dropout ratio for the classification head.
        adv_grad_factor (:obj:`float`, `optional`): This factor is multiplied by the KL loss grad and
            the result is added to the original embedding.
            For more details please check https://arxiv.org/abs/1908.04577
            This value usually lies in the range 1e-3~1e-7
        adv_bound (:obj:`float`, `optional`): adv_bound is used to clip the top and bottom bounds of
            the produced embedding.
            If not provided, 2 * sigma will be used as the adv_bound factor
        sigma (:obj:`float`, `optional`): The std used to produce a zero-mean normal distribution.
            If adv_bound is not provided, 2 * sigma will be used as the adv_bound factor
"""

model_type = 'sbert'

def __init__(self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act='gelu',
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
position_embedding_type='absolute',
use_cache=True,
classifier_dropout=None,
**kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.classifier_dropout = classifier_dropout
# adv_grad_factor, used in adv loss.
# Users can check adv_utils.py for details.
        # if adv_grad_factor is set to None, no adv loss will be applied to the model.
self.adv_grad_factor = 5e-5 if 'adv_grad_factor' not in kwargs else kwargs[
'adv_grad_factor']
# sigma value, used in adv loss.
self.sigma = 5e-6 if 'sigma' not in kwargs else kwargs['sigma']
# adv_bound value, used in adv loss.
self.adv_bound = 2 * self.sigma if 'adv_bound' not in kwargs else kwargs[
'adv_bound']
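
A brief sketch of how the adversarial-training knobs end up on the config; the values mirror the defaults above and the snippet is only illustrative.

from modelscope.models.nlp.backbones.structbert.configuration_sbert import SbertConfig

config = SbertConfig(
    vocab_size=21128,       # matches the vocab size in the updated json config
    hidden_size=768,
    adv_grad_factor=5e-5,   # set to None to disable the adv loss
    sigma=5e-6)
assert config.adv_bound == 2 * config.sigma  # adv_bound defaults to 2 * sigma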

modelscope/models/nlp/backbones/structbert/modeling_sbert.py (+815, -0)

@@ -0,0 +1,815 @@
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions, ModelOutput)
from transformers.modeling_utils import (apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer)

from .....metainfo import Models
from .....utils.constant import Fields
from .....utils.logger import get_logger
from ....base import TorchModel
from ....builder import BACKBONES
from .configuration_sbert import SbertConfig

logger = get_logger(__name__)


@BACKBONES.register_module(Fields.nlp, module_name=Models.structbert)
class SbertModel(TorchModel, PreTrainedModel):
"""

The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
    set to :obj:`True`. To be used in a Seq2Seq model, the model needs to be initialized with both the :obj:`is_decoder`
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
input to the forward pass.
"""

def __init__(self, model_dir=None, add_pooling_layer=True, **config):
"""
Args:
model_dir (str, optional): The model checkpoint directory. Defaults to None.
            add_pooling_layer (bool, optional): whether to add a pooling layer on top of the encoder output. Defaults to True.
"""
config = SbertConfig(**config)
super().__init__(model_dir)
self.config = config

self.embeddings = SbertEmbeddings(config)
self.encoder = SbertEncoder(config)

self.pooler = SbertPooler(config) if add_pooling_layer else None
self.init_weights()

def get_input_embeddings(self):
return self.embeddings.word_embeddings

def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value

def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)

def forward(self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs):
r"""
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`
, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:

- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers`
with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads,
sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
"""

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else
self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False

if input_ids is not None and inputs_embeds is not None:
raise ValueError(
'You cannot specify both input_ids and inputs_embeds at the same time'
)
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError(
'You have to specify either input_ids or inputs_embeds')

batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device

# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[
2] if past_key_values is not None else 0

if attention_mask is None:
attention_mask = torch.ones(
((batch_size, seq_length + past_key_values_length)),
device=device)

if token_type_ids is None:
if hasattr(self.embeddings, 'token_type_ids'):
buffered_token_type_ids = self.embeddings.token_type_ids[:, :
seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
batch_size, seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(
input_shape, dtype=torch.long, device=device)

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
attention_mask, input_shape, device)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size(
)
encoder_hidden_shape = (encoder_batch_size,
encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(
encoder_attention_mask)
else:
encoder_extended_attention_mask = None

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask,
self.config.num_hidden_layers)

embedding_output, orignal_embeds = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
return_inputs_embeds=True,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(
sequence_output) if self.pooler is not None else None

if not return_dict:
return (sequence_output,
pooled_output) + encoder_outputs[1:] + (orignal_embeds, )

return BaseModelOutputWithPoolingAndCrossAttentionsWithEmbedding(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
embedding_output=orignal_embeds)

def extract_sequence_outputs(self, outputs):
return outputs['last_hidden_state']

def extract_pooled_outputs(self, outputs):
return outputs['pooler_output']


class SbertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""

def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(
config.vocab_size,
config.hidden_size,
padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings,
config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
config.hidden_size)

# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config,
'position_embedding_type',
'absolute')
self.register_buffer(
'position_ids',
torch.arange(config.max_position_embeddings).expand((1, -1)))
if version.parse(torch.__version__) > version.parse('1.6.0'):
self.register_buffer(
'token_type_ids',
torch.zeros(
self.position_ids.size(),
dtype=torch.long,
device=self.position_ids.device),
persistent=False,
)

def forward(self,
input_ids=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
past_key_values_length=0,
return_inputs_embeds=False):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]

seq_length = input_shape[1]

if position_ids is None:
position_ids = self.position_ids[:,
past_key_values_length:seq_length
+ past_key_values_length]

# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
        # when it's auto-generated; the registered buffer helps users when tracing the model without passing token_type_ids
# issue #5664
if token_type_ids is None:
if hasattr(self, 'token_type_ids'):
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
input_shape[0], seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(
input_shape,
dtype=torch.long,
device=self.position_ids.device)

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)

embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == 'absolute':
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
if not return_inputs_embeds:
return embeddings
else:
return embeddings, inputs_embeds


class SbertSelfAttention(nn.Module):

def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
config, 'embedding_size'):
raise ValueError(
f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention '
f'heads ({config.num_attention_heads})')

self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size
/ config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size

self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = getattr(config,
'position_embedding_type',
'absolute')
if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query':
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(
2 * config.max_position_embeddings - 1,
self.attention_head_size)

self.is_decoder = config.is_decoder

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads,
self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
mixed_query_layer = self.query(hidden_states)

# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(
self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(
self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))

query_layer = self.transpose_for_scores(mixed_query_layer)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer,
key_layer.transpose(-1, -2))

if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query':
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(
seq_length, dtype=torch.long,
device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(
seq_length, dtype=torch.long,
device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(
distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(
dtype=query_layer.dtype) # fp16 compatibility

if self.position_embedding_type == 'relative_key':
relative_position_scores = torch.einsum(
'bhld,lrd->bhlr', query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == 'relative_key_query':
relative_position_scores_query = torch.einsum(
'bhld,lrd->bhlr', query_layer, positional_embedding)
relative_position_scores_key = torch.einsum(
'bhrd,lrd->bhlr', key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

attention_scores = attention_scores / math.sqrt(
self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in SbertModel forward() function)
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)

# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = torch.matmul(attention_probs, value_layer)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (
self.all_head_size, )
context_layer = context_layer.view(*new_context_layer_shape)

outputs = (context_layer,
attention_probs) if output_attentions else (context_layer, )

if self.is_decoder:
outputs = outputs + (past_key_value, )
return outputs


class SbertSelfOutput(nn.Module):

def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states


class SbertAttention(nn.Module):

def __init__(self, config):
super().__init__()
self.self = SbertSelfAttention(config)
self.output = SbertSelfOutput(config)
self.pruned_heads = set()

def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads,
self.self.attention_head_size, self.pruned_heads)

# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(
heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)

def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,
) + self_outputs[1:] # add attentions if we output them
return outputs


class SbertIntermediate(nn.Module):

def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act

def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states


class SbertOutput(nn.Module):

def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states


class SbertLayer(nn.Module):

def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = SbertAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(
f'{self} should be used as a decoder model if cross attention is added'
)
self.crossattention = SbertAttention(config)
self.intermediate = SbertIntermediate(config)
self.output = SbertOutput(config)

def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:
2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]

# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[
1:] # add self attentions if we output attention weights

cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, 'crossattention'):
raise ValueError(
f'If `encoder_hidden_states` are passed, {self} has to be instantiated'
f'with cross-attention layers by setting `config.add_cross_attention=True`'
)

# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[
-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[
1:-1] # add cross attentions if we output attention weights

# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value

layer_output = apply_chunking_to_forward(self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output)
outputs = (layer_output, ) + outputs

# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value, )

return outputs

def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output


class SbertEncoder(nn.Module):

def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList(
[SbertLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False

def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = (
) if output_attentions and self.config.add_cross_attention else None

next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states, )

layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[
i] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:

if use_cache:
logger.warning(
'`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
)
use_cache = False

def create_custom_forward(module):

def custom_forward(*inputs):
return module(*inputs, past_key_value,
output_attentions)

return custom_forward

layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)

hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1], )
if output_attentions:
all_self_attentions = all_self_attentions + (
layer_outputs[1], )
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (
layer_outputs[2], )

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states, )

if not return_dict:
return tuple(v for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
] if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)


class SbertPooler(nn.Module):

def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()

def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output


@dataclass
class SbertForPreTrainingOutput(ModelOutput):
"""
    Output type of :class:`SbertForPreTraining`.

Args:
loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
Total loss as the sum of the masked language modeling loss and the next sequence prediction
(classification) loss.
prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.

Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when
``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.

Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""

loss: Optional[torch.FloatTensor] = None
prediction_logits: torch.FloatTensor = None
seq_relationship_logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class BaseModelOutputWithPoolingAndCrossAttentionsWithEmbedding(
BaseModelOutputWithPoolingAndCrossAttentions):
embedding_output: torch.FloatTensor = None
logits: Optional[Union[tuple, torch.FloatTensor]] = None
kwargs: dict = None

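For reference, a minimal sketch (illustrative only, not part of this commit) of how the SbertEncoder and SbertPooler defined above compose. `config` is assumed to be an SbertConfig-like object providing hidden_size, num_hidden_layers and the other attributes read in the constructors, and the attention mask is assumed to already be an additive extended mask.

    import torch

    def encode_and_pool(config, embedding_output, extended_attention_mask=None):
        # embedding_output: (batch_size, seq_len, hidden_size) produced by the embedding layer
        encoder = SbertEncoder(config)
        pooler = SbertPooler(config)
        encoder_outputs = encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            return_dict=True)
        sequence_output = encoder_outputs.last_hidden_state
        # "pooling" here is simply Tanh(Linear(first-token hidden state))
        pooled_output = pooler(sequence_output)
        return sequence_output, pooled_output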
+3 -0   modelscope/models/nlp/heads/__init__.py

@@ -0,0 +1,3 @@
from .sequence_classification_head import SequenceClassificationHead

__all__ = ['SequenceClassificationHead']

+44 -0   modelscope/models/nlp/heads/sequence_classification_head.py

@@ -0,0 +1,44 @@
import importlib
from typing import Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from ....metainfo import Heads
from ....outputs import OutputKeys
from ....utils.constant import Tasks
from ...base import TorchHead
from ...builder import HEADS


@HEADS.register_module(
Tasks.text_classification, module_name=Heads.text_classification)
class SequenceClassificationHead(TorchHead):

def __init__(self, **kwargs):
super().__init__(**kwargs)
config = self.config
self.num_labels = config.num_labels
self.config = config
classifier_dropout = (
config['classifier_dropout'] if config.get('classifier_dropout')
is not None else config['hidden_dropout_prob'])
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config['hidden_size'],
config['num_labels'])

def forward(self, inputs=None):
if isinstance(inputs, dict):
assert inputs.get('pooled_output') is not None
pooled_output = inputs.get('pooled_output')
else:
pooled_output = inputs
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
return {OutputKeys.LOGITS: logits}

def compute_loss(self, outputs: Dict[str, torch.Tensor],
labels) -> Dict[str, torch.Tensor]:
logits = outputs[OutputKeys.LOGITS]
return {OutputKeys.LOSS: F.cross_entropy(logits, labels)}

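A rough usage sketch of the head above, assuming TorchHead stores its keyword arguments in a ConfigDict-style self.config (as the mixed attribute/item access in __init__ suggests); the concrete values are made up for illustration.

    import torch

    head = SequenceClassificationHead(
        hidden_size=768, num_labels=2, hidden_dropout_prob=0.1)
    outputs = head.forward(torch.rand(4, 768))                    # {'logits': tensor of shape (4, 2)}
    loss = head.compute_loss(outputs, torch.tensor([0, 1, 1, 0]))  # {'loss': scalar tensor}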
+1 -2   modelscope/models/nlp/palm_for_text_generation.py

@@ -2,8 +2,7 @@ from typing import Dict

 from ...metainfo import Models
 from ...utils.constant import Tasks
-from ..base import Tensor
-from ..base_torch import TorchModel
+from ..base import Tensor, TorchModel
 from ..builder import MODELS

 __all__ = ['PalmForTextGeneration']


+3 -0   modelscope/models/nlp/sbert_for_sequence_classification.py

@@ -42,6 +42,9 @@ class SbertTextClassfier(SbertPreTrainedModel):
             return {'logits': logits, 'loss': loss}
         return {'logits': logits}

+
+    def build(model_dir, **model_args):
+        return SbertTextClassfier.from_pretrained(model_dir, **model_args)

 class SbertForSequenceClassificationBase(Model):

+85 -0   modelscope/models/nlp/sequence_classification.py

@@ -0,0 +1,85 @@
import os
from typing import Any, Dict

import json
import numpy as np

from ...metainfo import TaskModels
from ...outputs import OutputKeys
from ...utils.constant import Tasks
from ..builder import MODELS
from .task_model import SingleBackboneTaskModelBase

__all__ = ['SequenceClassificationModel']


@MODELS.register_module(
Tasks.sentiment_classification, module_name=TaskModels.text_classification)
@MODELS.register_module(
Tasks.text_classification, module_name=TaskModels.text_classification)
class SequenceClassificationModel(SingleBackboneTaskModelBase):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the sequence classification model from the `model_dir` path.

Args:
model_dir (str): the model path.
"""
super().__init__(model_dir, *args, **kwargs)
if 'base_model_prefix' in kwargs:
self._base_model_prefix = kwargs['base_model_prefix']

backbone_cfg = self.cfg.backbone
head_cfg = self.cfg.head

# get the num_labels from label_mapping.json
self.id2label = {}
self.label_path = os.path.join(model_dir, 'label_mapping.json')
if os.path.exists(self.label_path):
with open(self.label_path) as f:
self.label_mapping = json.load(f)
self.id2label = {
idx: name
for name, idx in self.label_mapping.items()
}
head_cfg['num_labels'] = len(self.label_mapping)

self.build_backbone(backbone_cfg)
self.build_head(head_cfg)

def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
outputs = super().forward(input)
sequence_output, pooled_output = self.extract_backbone_outputs(outputs)
outputs = self.head.forward(pooled_output)
if 'labels' in input:
loss = self.compute_loss(outputs, input['labels'])
outputs.update(loss)
return outputs

def extract_logits(self, outputs):
return outputs[OutputKeys.LOGITS].cpu().detach()

def extract_backbone_outputs(self, outputs):
sequence_output = None
pooled_output = None
if hasattr(self.backbone, 'extract_sequence_outputs'):
sequence_output = self.backbone.extract_sequence_outputs(outputs)
if hasattr(self.backbone, 'extract_pooled_outputs'):
pooled_output = self.backbone.extract_pooled_outputs(outputs)
return sequence_output, pooled_output

def compute_loss(self, outputs, labels):
loss = self.head.compute_loss(outputs, labels)
return loss

def postprocess(self, input, **kwargs):
logits = self.extract_logits(input)
probs = logits.softmax(-1).numpy()
pred = logits.argmax(-1).numpy()
logits = logits.numpy()
res = {
OutputKeys.PREDICTIONS: pred,
OutputKeys.PROBABILITIES: probs,
OutputKeys.LOGITS: logits
}
return res

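The quickest way to exercise this backbone-head task model end to end is through the pipeline API, mirroring tests/pipelines/test_sentiment_classification.py further below (the model id and the 'beta' revision come from those tests and may change):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    pipe = pipeline(
        task=Tasks.sentiment_classification,
        model='damo/nlp_structbert_sentiment-classification_chinese-base',
        model_revision='beta')
    # the task model's postprocess() contributes 'predictions', 'probabilities' and 'logits'
    print(pipe(input='...'))  # replace '...' with an actual sentence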
+0 -0   modelscope/models/nlp/space/modules/__init__.py


modelscope/models/nlp/space/dialog_intent_prediction_model.py → modelscope/models/nlp/space_for_dialog_intent_prediction.py

@@ -3,15 +3,14 @@
 import os
 from typing import Any, Dict

-from ....metainfo import Models
-from ....preprocessors.space.fields.intent_field import IntentBPETextField
-from ....trainers.nlp.space.trainer.intent_trainer import IntentTrainer
-from ....utils.config import Config
-from ....utils.constant import ModelFile, Tasks
-from ...base import Model, Tensor
-from ...builder import MODELS
-from .model.generator import Generator
-from .model.model_base import SpaceModelBase
+from ...metainfo import Models
+from ...preprocessors.space.fields.intent_field import IntentBPETextField
+from ...trainers.nlp.space.trainer.intent_trainer import IntentTrainer
+from ...utils.config import Config
+from ...utils.constant import ModelFile, Tasks
+from ..base import Model, Tensor
+from ..builder import MODELS
+from .backbones import SpaceGenerator, SpaceModelBase

 __all__ = ['SpaceForDialogIntent']

@@ -37,7 +36,8 @@ class SpaceForDialogIntent(Model):
             'text_field',
             IntentBPETextField(self.model_dir, config=self.config))

-        self.generator = Generator.create(self.config, reader=self.text_field)
+        self.generator = SpaceGenerator.create(
+            self.config, reader=self.text_field)
         self.model = SpaceModelBase.create(
             model_dir=model_dir,
             config=self.config,

modelscope/models/nlp/space/dialog_modeling_model.py → modelscope/models/nlp/space_for_dialog_modeling.py

@@ -3,15 +3,14 @@
 import os
 from typing import Any, Dict, Optional

-from ....metainfo import Models
-from ....preprocessors.space.fields.gen_field import MultiWOZBPETextField
-from ....trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer
-from ....utils.config import Config
-from ....utils.constant import ModelFile, Tasks
-from ...base import Model, Tensor
-from ...builder import MODELS
-from .model.generator import Generator
-from .model.model_base import SpaceModelBase
+from ...metainfo import Models
+from ...preprocessors.space.fields.gen_field import MultiWOZBPETextField
+from ...trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer
+from ...utils.config import Config
+from ...utils.constant import ModelFile, Tasks
+from ..base import Model, Tensor
+from ..builder import MODELS
+from .backbones import SpaceGenerator, SpaceModelBase

 __all__ = ['SpaceForDialogModeling']

@@ -35,7 +34,8 @@ class SpaceForDialogModeling(Model):
         self.text_field = kwargs.pop(
             'text_field',
             MultiWOZBPETextField(self.model_dir, config=self.config))
-        self.generator = Generator.create(self.config, reader=self.text_field)
+        self.generator = SpaceGenerator.create(
+            self.config, reader=self.text_field)
         self.model = SpaceModelBase.create(
             model_dir=model_dir,
             config=self.config,

modelscope/models/nlp/space/dialog_state_tracking_model.py → modelscope/models/nlp/space_for_dialog_state_tracking.py

@@ -2,10 +2,10 @@ import os
 from typing import Any, Dict

 from modelscope.utils.constant import Tasks
-from ....metainfo import Models
-from ....utils.nlp.space.utils_dst import batch_to_device
-from ...base import Model, Tensor
-from ...builder import MODELS
+from ...metainfo import Models
+from ...utils.nlp.space.utils_dst import batch_to_device
+from ..base import Model, Tensor
+from ..builder import MODELS

 __all__ = ['SpaceForDialogStateTracking']



+489 -0   modelscope/models/nlp/task_model.py

@@ -0,0 +1,489 @@
import os.path
import re
from abc import ABC
from collections import OrderedDict
from typing import Any, Dict

import torch
from torch import nn

from ...utils.config import ConfigDict
from ...utils.constant import Fields, Tasks
from ...utils.logger import get_logger
from ...utils.utils import if_func_recieve_dict_inputs
from ..base import TorchModel
from ..builder import build_backbone, build_head

logger = get_logger(__name__)

__all__ = ['EncoderDecoderTaskModelBase', 'SingleBackboneTaskModelBase']


def _repr(modules, depth=1):
# model name log level control
if depth == 0:
return modules._get_name()
# We treat the extra repr like the sub-module, one item per line
extra_lines = []
extra_repr = modules.extra_repr()
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split('\n')
child_lines = []

def _addindent(s_, numSpaces):
s = s_.split('\n')
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
s = [(numSpaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
return s

for key, module in modules._modules.items():
mod_str = _repr(module, depth - 1)
mod_str = _addindent(mod_str, 2)
child_lines.append('(' + key + '): ' + mod_str)
lines = extra_lines + child_lines

main_str = modules._get_name() + '('
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += '\n ' + '\n '.join(lines) + '\n'

main_str += ')'
return main_str


class BaseTaskModel(TorchModel, ABC):
""" Base task model interface for nlp

"""
# keys to ignore when load missing
_keys_to_ignore_on_load_missing = None
# keys to ignore when load unexpected
_keys_to_ignore_on_load_unexpected = None
# backbone prefix, default None
_backbone_prefix = None

def __init__(self, model_dir: str, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
self.cfg = ConfigDict(kwargs)

def __repr__(self):
# only log backbone and head name
depth = 1
return _repr(self, depth)

@classmethod
def _instantiate(cls, **kwargs):
model_dir = kwargs.get('model_dir')
model = cls(**kwargs)
model.load_checkpoint(model_local_dir=model_dir, **kwargs)
return model

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
pass

def load_checkpoint(self,
model_local_dir,
default_dtype=None,
load_state_fn=None,
**kwargs):
"""
Load model checkpoint file and feed the parameters into the model.
Args:
model_local_dir: The actual checkpoint dir on local disk.
default_dtype: Set the default float type by 'torch.set_default_dtype'
load_state_fn: An optional load_state_fn used to load state_dict into the model.

Returns:

"""
# TODO Sharded ckpt
ckpt_file = os.path.join(model_local_dir, 'pytorch_model.bin')
state_dict = torch.load(ckpt_file, map_location='cpu')
if default_dtype is not None:
torch.set_default_dtype(default_dtype)

missing_keys, unexpected_keys, mismatched_keys, error_msgs = self._load_checkpoint(
state_dict,
load_state_fn=load_state_fn,
ignore_mismatched_sizes=True,
_fast_init=True,
)

return {
'missing_keys': missing_keys,
'unexpected_keys': unexpected_keys,
'mismatched_keys': mismatched_keys,
'error_msgs': error_msgs,
}

def _load_checkpoint(
self,
state_dict,
load_state_fn,
ignore_mismatched_sizes,
_fast_init,
):
# Retrieve missing & unexpected_keys
model_state_dict = self.state_dict()
prefix = self._backbone_prefix

# add head prefix
new_state_dict = OrderedDict()
for name, module in state_dict.items():
if not name.startswith(prefix) and not name.startswith('head'):
new_state_dict['.'.join(['head', name])] = module
else:
new_state_dict[name] = module
state_dict = new_state_dict

loaded_keys = [k for k in state_dict.keys()]
expected_keys = list(model_state_dict.keys())

def _fix_key(key):
if 'beta' in key:
return key.replace('beta', 'bias')
if 'gamma' in key:
return key.replace('gamma', 'weight')
return key

original_loaded_keys = loaded_keys
loaded_keys = [_fix_key(key) for key in loaded_keys]

if len(prefix) > 0:
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
expects_prefix_module = any(
s.startswith(prefix) for s in expected_keys)
else:
has_prefix_module = False
expects_prefix_module = False

# key re-naming operations are never done on the keys
# that are loaded, but always on the keys of the newly initialized model
remove_prefix_from_model = not has_prefix_module and expects_prefix_module
add_prefix_to_model = has_prefix_module and not expects_prefix_module

if remove_prefix_from_model:
expected_keys_not_prefixed = [
s for s in expected_keys if not s.startswith(prefix)
]
expected_keys = [
'.'.join(s.split('.')[1:]) if s.startswith(prefix) else s
for s in expected_keys
]
elif add_prefix_to_model:
expected_keys = ['.'.join([prefix, s]) for s in expected_keys]

missing_keys = list(set(expected_keys) - set(loaded_keys))
unexpected_keys = list(set(loaded_keys) - set(expected_keys))

if self._keys_to_ignore_on_load_missing is not None:
for pat in self._keys_to_ignore_on_load_missing:
missing_keys = [
k for k in missing_keys if re.search(pat, k) is None
]

if self._keys_to_ignore_on_load_unexpected is not None:
for pat in self._keys_to_ignore_on_load_unexpected:
unexpected_keys = [
k for k in unexpected_keys if re.search(pat, k) is None
]

if _fast_init:
            # retrieve uninitialized modules and initialize them
uninitialized_modules = self.retrieve_modules_from_names(
missing_keys,
prefix=prefix,
add_prefix=add_prefix_to_model,
remove_prefix=remove_prefix_from_model)
for module in uninitialized_modules:
self._init_weights(module)

# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = ''
model_to_load = self
if len(prefix) > 0 and not hasattr(self, prefix) and has_prefix_module:
start_prefix = prefix + '.'
if len(prefix) > 0 and hasattr(self, prefix) and not has_prefix_module:
model_to_load = getattr(self, prefix)
if any(key in expected_keys_not_prefixed for key in loaded_keys):
raise ValueError(
'The state dictionary of the model you are trying to load is corrupted. Are you sure it was '
'properly saved?')

def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
if remove_prefix_from_model:
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
model_key = f'{prefix}.{checkpoint_key}'
elif add_prefix_to_model:
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key = '.'.join(checkpoint_key.split('.')[1:])

if (model_key in model_state_dict):
model_shape = model_state_dict[model_key].shape
checkpoint_shape = state_dict[checkpoint_key].shape
if (checkpoint_shape != model_shape):
mismatched_keys.append(
(checkpoint_key,
state_dict[checkpoint_key].shape,
model_state_dict[model_key].shape))
del state_dict[checkpoint_key]
return mismatched_keys

def _load_state_dict_into_model(model_to_load, state_dict,
start_prefix):
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)

# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata

error_msgs = []

if load_state_fn is not None:
load_state_fn(
model_to_load,
state_dict,
prefix=start_prefix,
local_metadata=None,
error_msgs=error_msgs)
else:

def load(module: nn.Module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(
prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [],
error_msgs)
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')

load(model_to_load, prefix=start_prefix)

return error_msgs

# Whole checkpoint
mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
)
error_msgs = _load_state_dict_into_model(model_to_load, state_dict,
start_prefix)

if len(error_msgs) > 0:
error_msg = '\n\t'.join(error_msgs)
raise RuntimeError(
f'Error(s) in loading state_dict for {self.__class__.__name__}:\n\t{error_msg}'
)

if len(unexpected_keys) > 0:
logger.warning(
f'Some weights of the model checkpoint were not used when'
f' initializing {self.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are'
f' initializing {self.__class__.__name__} from the checkpoint of a model trained on another task or'
' with another architecture (e.g. initializing a BertForSequenceClassification model from a'
' BertForPreTraining model).\n- This IS NOT expected if you are initializing'
f' {self.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical'
' (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).'
)
else:
logger.info(
f'All model checkpoint weights were used when initializing {self.__class__.__name__}.\n'
)
if len(missing_keys) > 0:
logger.warning(
f'Some weights of {self.__class__.__name__} were not initialized from the model checkpoint'
f' and are newly initialized: {missing_keys}\nYou should probably'
' TRAIN this model on a down-stream task to be able to use it for predictions and inference.'
)
elif len(mismatched_keys) == 0:
logger.info(
                f'All the weights of {self.__class__.__name__} were initialized from the model checkpoint.\n'
f'If your task is similar to the task the model of the checkpoint'
f' was trained on, you can already use {self.__class__.__name__} for predictions without further'
' training.')
if len(mismatched_keys) > 0:
mismatched_warning = '\n'.join([
f'- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated'
for key, shape1, shape2 in mismatched_keys
])
logger.warning(
f'Some weights of {self.__class__.__name__} were not initialized from the model checkpoint'
f' and are newly initialized because the shapes did not'
f' match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able'
' to use it for predictions and inference.')

return missing_keys, unexpected_keys, mismatched_keys, error_msgs

def retrieve_modules_from_names(self,
names,
prefix=None,
add_prefix=False,
remove_prefix=False):
module_keys = set(['.'.join(key.split('.')[:-1]) for key in names])

# torch.nn.ParameterList is a special case where two parameter keywords
# are appended to the module name, *e.g.* bert.special_embeddings.0
module_keys = module_keys.union(
set([
'.'.join(key.split('.')[:-2]) for key in names
if key[-1].isdigit()
]))

retrieved_modules = []
        # retrieve all modules that have at least one missing weight name
for name, module in self.named_modules():
if remove_prefix:
name = '.'.join(
name.split('.')[1:]) if name.startswith(prefix) else name
elif add_prefix:
name = '.'.join([prefix, name]) if len(name) > 0 else prefix

if name in module_keys:
retrieved_modules.append(module)

return retrieved_modules


class SingleBackboneTaskModelBase(BaseTaskModel):
"""
This is the base class of any single backbone nlp task classes.
"""
# The backbone prefix defaults to "bert"
_backbone_prefix = 'bert'

# The head prefix defaults to "head"
_head_prefix = 'head'

def __init__(self, model_dir: str, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)

def build_backbone(self, cfg):
if 'prefix' in cfg:
self._backbone_prefix = cfg['prefix']
backbone = build_backbone(cfg, field=Fields.nlp)
        setattr(self, self._backbone_prefix, backbone)

def build_head(self, cfg):
if 'prefix' in cfg:
self._head_prefix = cfg['prefix']
head = build_head(cfg)
setattr(self, self._head_prefix, head)
return head

@property
def backbone(self):
if 'backbone' != self._backbone_prefix:
return getattr(self, self._backbone_prefix)
return super().__getattr__('backbone')

@property
def head(self):
if 'head' != self._head_prefix:
return getattr(self, self._head_prefix)
return super().__getattr__('head')

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""default forward method is the backbone-only forward"""
if if_func_recieve_dict_inputs(self.backbone.forward, input):
outputs = self.backbone.forward(input)
else:
outputs = self.backbone.forward(**input)
return outputs


class EncoderDecoderTaskModelBase(BaseTaskModel):
"""
This is the base class of encoder-decoder nlp task classes.
"""
# The encoder backbone prefix, default to "encoder"
_encoder_prefix = 'encoder'
# The decoder backbone prefix, default to "decoder"
_decoder_prefix = 'decoder'
    # The key in cfg specifying the encoder type
_encoder_key_in_cfg = 'encoder_type'
    # The key in cfg specifying the decoder type
_decoder_key_in_cfg = 'decoder_type'

def __init__(self, model_dir: str, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)

def build_encoder(self):
encoder = build_backbone(
self.cfg,
type_name=self._encoder_key_in_cfg,
task_name=Tasks.backbone)
setattr(self, self._encoder_prefix, encoder)
return encoder

def build_decoder(self):
decoder = build_backbone(
self.cfg,
type_name=self._decoder_key_in_cfg,
task_name=Tasks.backbone)
setattr(self, self._decoder_prefix, decoder)
return decoder

@property
def encoder_(self):
return getattr(self, self._encoder_prefix)

@property
def decoder_(self):
return getattr(self, self._decoder_prefix)

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
if if_func_recieve_dict_inputs(self.encoder_.forward, input):
encoder_outputs = self.encoder_.forward(input)
else:
encoder_outputs = self.encoder_.forward(**input)
decoder_inputs = self.project_decoder_inputs_and_mediate(
input, encoder_outputs)
if if_func_recieve_dict_inputs(self.decoder_.forward, input):
outputs = self.decoder_.forward(decoder_inputs)
else:
outputs = self.decoder_.forward(**decoder_inputs)

return outputs

def project_decoder_inputs_and_mediate(self, input, encoder_outputs):
return {**input, **encoder_outputs}

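For context, a hypothetical subclass sketch (not part of the commit) of how SingleBackboneTaskModelBase is intended to be used: a task model builds its backbone and head from the backbone/head sections of the model configuration and combines their outputs, exactly as SequenceClassificationModel does above. The extract_pooled_outputs call assumes the configured backbone exposes that method, which SequenceClassificationModel also checks for.

    class MyTaskModel(SingleBackboneTaskModelBase):

        def __init__(self, model_dir, *args, **kwargs):
            super().__init__(model_dir, *args, **kwargs)
            # self.cfg.backbone / self.cfg.head come from the model's configuration,
            # e.g. backbone={'type': 'sbert', ...}, head={'type': 'text-classification', 'num_labels': 2}
            self.build_backbone(self.cfg.backbone)
            self.build_head(self.cfg.head)

        def forward(self, input):
            # backbone-only forward from the base class, then the task-specific head
            backbone_outputs = super().forward(input)
            pooled = self.backbone.extract_pooled_outputs(backbone_outputs)
            return self.head.forward(pooled)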
+3 -0   modelscope/outputs.py

@@ -4,6 +4,7 @@ from modelscope.utils.constant import Tasks


 class OutputKeys(object):
+    LOSS = 'loss'
     LOGITS = 'logits'
     SCORES = 'scores'
     LABEL = 'label'
@@ -22,6 +23,8 @@
     TRANSLATION = 'translation'
     RESPONSE = 'response'
     PREDICTION = 'prediction'
+    PREDICTIONS = 'predictions'
+    PROBABILITIES = 'probabilities'
     DIALOG_STATES = 'dialog_states'
     VIDEO_EMBEDDING = 'video_embedding'




+2 -1   modelscope/pipelines/builder.py

@@ -30,7 +30,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.nli: (Pipelines.nli, 'damo/nlp_structbert_nli_chinese-base'),
     Tasks.sentiment_classification:
     (Pipelines.sentiment_classification,
-     'damo/nlp_structbert_sentiment-classification_chinese-base'),
+     'damo/nlp_structbert_sentiment-classification_chinese-base'
+     ),  # TODO: revise back after passing the pr
     Tasks.image_matting: (Pipelines.image_matting,
                           'damo/cv_unet_image-matting'),
     Tasks.text_classification: (Pipelines.sentiment_analysis,


+1 -1   modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py

@@ -2,10 +2,10 @@

 from typing import Any, Dict, Union

-from modelscope.outputs import OutputKeys
 from ...metainfo import Pipelines
 from ...models import Model
 from ...models.nlp import SpaceForDialogIntent
+from ...outputs import OutputKeys
 from ...preprocessors import DialogIntentPredictionPreprocessor
 from ...utils.constant import Tasks
 from ..base import Pipeline


+1 -1   modelscope/pipelines/nlp/dialog_modeling_pipeline.py

@@ -2,10 +2,10 @@

 from typing import Dict, Union

-from modelscope.outputs import OutputKeys
 from ...metainfo import Pipelines
 from ...models import Model
 from ...models.nlp import SpaceForDialogModeling
+from ...outputs import OutputKeys
 from ...preprocessors import DialogModelingPreprocessor
 from ...utils.constant import Tasks
 from ..base import Pipeline, Tensor


+1 -1   modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py

@@ -1,8 +1,8 @@
 from typing import Any, Dict, Union

-from modelscope.outputs import OutputKeys
 from ...metainfo import Pipelines
 from ...models import Model, SpaceForDialogStateTracking
+from ...outputs import OutputKeys
 from ...preprocessors import DialogStateTrackingPreprocessor
 from ...utils.constant import Tasks
 from ..base import Pipeline


+6 -6   modelscope/pipelines/nlp/sentiment_classification_pipeline.py

@@ -6,7 +6,7 @@ import torch
 from modelscope.outputs import OutputKeys
 from ...metainfo import Pipelines
 from ...models import Model
-from ...models.nlp import SbertForSentimentClassification
+from ...models.nlp import SequenceClassificationModel
 from ...preprocessors import SentimentClassificationPreprocessor
 from ...utils.constant import Tasks
 from ..base import Pipeline
@@ -21,7 +21,7 @@ __all__ = ['SentimentClassificationPipeline']
 class SentimentClassificationPipeline(Pipeline):

     def __init__(self,
-                 model: Union[SbertForSentimentClassification, str],
+                 model: Union[SequenceClassificationModel, str],
                  preprocessor: SentimentClassificationPreprocessor = None,
                  first_sequence='first_sequence',
                  second_sequence='second_sequence',
@@ -29,14 +29,14 @@ class SentimentClassificationPipeline(Pipeline):
         """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction

         Args:
-            model (SbertForSentimentClassification): a model instance
+            model (SequenceClassificationModel): a model instance
             preprocessor (SentimentClassificationPreprocessor): a preprocessor instance
         """
-        assert isinstance(model, str) or isinstance(model, SbertForSentimentClassification), \
-            'model must be a single str or SbertForSentimentClassification'
+        assert isinstance(model, str) or isinstance(model, SequenceClassificationModel), \
+            'model must be a single str or SequenceClassificationModel'
         model = model if isinstance(
             model,
-            SbertForSentimentClassification) else Model.from_pretrained(model)
+            SequenceClassificationModel) else Model.from_pretrained(model)
         if preprocessor is None:
             preprocessor = SentimentClassificationPreprocessor(
                 model.model_dir,


+11 -0   modelscope/preprocessors/base.py

@@ -3,12 +3,23 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict

+from modelscope.utils.constant import ModeKeys
+

 class Preprocessor(ABC):

     def __init__(self, *args, **kwargs):
+        self._mode = ModeKeys.INFERENCE
         pass

     @abstractmethod
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         pass
+
+    @property
+    def mode(self):
+        return self._mode
+
+    @mode.setter
+    def mode(self, value):
+        self._mode = value

+8 -12   modelscope/preprocessors/nlp.py

@@ -7,7 +7,7 @@ from transformers import AutoTokenizer

 from ..metainfo import Preprocessors
 from ..models import Model
-from ..utils.constant import Fields, InputFields
+from ..utils.constant import Fields, InputFields, ModeKeys
 from ..utils.hub import parse_label_mapping
 from ..utils.type_assert import type_assert
 from .base import Preprocessor
@@ -52,6 +52,7 @@ class NLPPreprocessorBase(Preprocessor):
         self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
         self.tokenize_kwargs = kwargs
         self.tokenizer = self.build_tokenizer(model_dir)
+        self.label2id = parse_label_mapping(self.model_dir)

     def build_tokenizer(self, model_dir):
         from sofa import SbertTokenizer
@@ -83,7 +84,12 @@ class NLPPreprocessorBase(Preprocessor):
         text_a = data.get(self.first_sequence)
         text_b = data.get(self.second_sequence, None)

-        return self.tokenizer(text_a, text_b, **self.tokenize_kwargs)
+        rst = self.tokenizer(text_a, text_b, **self.tokenize_kwargs)
+        if self._mode == ModeKeys.TRAIN:
+            rst = {k: v.squeeze() for k, v in rst.items()}
+        if self.label2id is not None and 'label' in data:
+            rst['label'] = self.label2id[str(data['label'])]
+        return rst


 @PREPROCESSORS.register_module(
@@ -200,16 +206,6 @@ class SentenceSimilarityFinetunePreprocessor(SentenceSimilarityPreprocessor):
     def __init__(self, model_dir: str, *args, **kwargs):
         kwargs['padding'] = 'max_length'
         super().__init__(model_dir, *args, **kwargs)
-        self.label2id = parse_label_mapping(self.model_dir)
-
-    @type_assert(object, (str, tuple, Dict))
-    def __call__(self, data: Union[str, tuple, Dict]) -> Dict[str, Any]:
-        rst = super().__call__(data)
-        rst = {k: v.squeeze() for k, v in rst.items()}
-        if self.label2id is not None and 'label' in data:
-            rst['labels'] = []
-            rst['labels'].append(self.label2id[str(data['label'])])
-        return rst


 @PREPROCESSORS.register_module(


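The net effect of the two preprocessor changes above, sketched for illustration (the exact constructor arguments of SentenceSimilarityPreprocessor and the field names are assumptions drawn from the defaults shown in the diff):

    from modelscope.preprocessors import SentenceSimilarityPreprocessor
    from modelscope.utils.constant import ModeKeys

    model_dir = '/path/to/local/model'   # hypothetical local model directory
    pre = SentenceSimilarityPreprocessor(model_dir)
    # inference mode by default: plain tokenizer outputs
    features = pre({'first_sequence': 'text a', 'second_sequence': 'text b'})

    pre.mode = ModeKeys.TRAIN
    # in train mode the tokenizer outputs are squeezed and, if label_mapping.json exists,
    # the raw 'label' value is mapped to its integer id
    features = pre({'first_sequence': 'text a', 'second_sequence': 'text b', 'label': '1'})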
+14 -7   modelscope/trainers/trainer.py

@@ -2,6 +2,7 @@
 import os.path
 import random
 import time
+from collections.abc import Mapping
 from distutils.version import LooseVersion
 from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
@@ -16,8 +17,7 @@

 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metrics import build_metric, task_default_metrics
-from modelscope.models.base import Model
-from modelscope.models.base_torch import TorchModel
+from modelscope.models.base import Model, TorchModel
 from modelscope.msdatasets.ms_dataset import MsDataset
 from modelscope.preprocessors import build_preprocessor
 from modelscope.preprocessors.base import Preprocessor
@@ -26,12 +26,13 @@ from modelscope.trainers.hooks.priority import Priority, get_priority
 from modelscope.trainers.lrscheduler.builder import build_lr_scheduler
 from modelscope.trainers.optimizer.builder import build_optimizer
 from modelscope.utils.config import Config, ConfigDict
-from modelscope.utils.constant import (Hubs, ModeKeys, ModelFile, Tasks,
-                                       TrainerStages)
+from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Hubs, ModeKeys,
+                                       ModelFile, Tasks, TrainerStages)
 from modelscope.utils.logger import get_logger
 from modelscope.utils.registry import build_from_cfg
 from modelscope.utils.tensor_utils import torch_default_data_collator
 from modelscope.utils.torch_utils import get_dist_info
+from modelscope.utils.utils import if_func_recieve_dict_inputs
 from .base import BaseTrainer
 from .builder import TRAINERS
 from .default_config import DEFAULT_CONFIG
@@ -79,13 +80,15 @@ class EpochBasedTrainer(BaseTrainer):
                  optimizers: Tuple[torch.optim.Optimizer,
                                    torch.optim.lr_scheduler._LRScheduler] = (None,
                                                                              None),
+                 model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
                  **kwargs):
         if isinstance(model, str):
             if os.path.exists(model):
                 self.model_dir = model if os.path.isdir(
                     model) else os.path.dirname(model)
             else:
-                self.model_dir = snapshot_download(model)
+                self.model_dir = snapshot_download(
+                    model, revision=model_revision)
             cfg_file = os.path.join(self.model_dir, ModelFile.CONFIGURATION)
             self.model = self.build_model()
         else:
@@ -112,6 +115,8 @@
             self.preprocessor = preprocessor
         elif hasattr(self.cfg, 'preprocessor'):
             self.preprocessor = self.build_preprocessor()
+        if self.preprocessor is not None:
+            self.preprocessor.mode = ModeKeys.TRAIN
        # TODO @wenmeng.zwm add data collator option
        # TODO how to fill device option?
         self.device = int(
@@ -264,7 +269,8 @@
             model = Model.from_pretrained(self.model_dir)
             if not isinstance(model, nn.Module) and hasattr(model, 'model'):
                 return model.model
-            return model
+            elif isinstance(model, nn.Module):
+                return model

     def collate_fn(self, data):
         """Prepare the input just before the forward function.
@@ -307,7 +313,8 @@
         model.train()
         self._mode = ModeKeys.TRAIN
         inputs = self.collate_fn(inputs)
-        if not isinstance(model, Model) and isinstance(inputs, dict):
+        if isinstance(inputs, Mapping) and not if_func_recieve_dict_inputs(
+                model.forward, inputs):
             train_outputs = model.forward(**inputs)
         else:
             train_outputs = model.forward(inputs)


+9 -2   modelscope/trainers/utils/inference.py

@@ -5,6 +5,7 @@ import pickle
 import shutil
 import tempfile
 import time
+from collections.abc import Mapping

 import torch
 from torch import distributed as dist
@@ -12,6 +13,7 @@ from tqdm import tqdm

 from modelscope.models.base import Model
 from modelscope.utils.torch_utils import get_dist_info
+from modelscope.utils.utils import if_func_recieve_dict_inputs


 def single_gpu_test(model,
@@ -36,7 +38,10 @@
         if data_collate_fn is not None:
             data = data_collate_fn(data)
         with torch.no_grad():
-            if not isinstance(model, Model):
+            if isinstance(data,
+                          Mapping) and not if_func_recieve_dict_inputs(
+                              model.forward, data):
                 result = model(**data)
             else:
                 result = model(data)
@@ -87,7 +92,9 @@ def multi_gpu_test(model,
         if data_collate_fn is not None:
             data = data_collate_fn(data)
         with torch.no_grad():
-            if not isinstance(model, Model):
+            if isinstance(data,
+                          Mapping) and not if_func_recieve_dict_inputs(
+                              model.forward, data):
                 result = model(**data)
             else:
                 result = model(data)


+2 -0   modelscope/utils/constant.py

@@ -57,6 +57,7 @@ class NLPTasks(object):
     summarization = 'summarization'
     question_answering = 'question-answering'
     zero_shot_classification = 'zero-shot-classification'
+    backbone = 'backbone'


 class AudioTasks(object):
@@ -173,6 +174,7 @@ DEFAULT_DATASET_REVISION = 'master'
 class ModeKeys:
     TRAIN = 'train'
     EVAL = 'eval'
+    INFERENCE = 'inference'


 class LogKeys:


+10 -5   modelscope/utils/registry.py

@@ -6,6 +6,7 @@ from typing import List, Tuple, Union
 from modelscope.utils.import_utils import requires
 from modelscope.utils.logger import get_logger

+TYPE_NAME = 'type'
 default_group = 'default'
 logger = get_logger()

@@ -159,15 +160,16 @@
         group_key (str, optional): The name of registry group from which
             module should be searched.
         default_args (dict, optional): Default initialization arguments.
+        type_name (str, optional): The name of the type in the config.
     Returns:
         object: The constructed object.
     """
     if not isinstance(cfg, dict):
         raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
-    if 'type' not in cfg:
-        if default_args is None or 'type' not in default_args:
+    if TYPE_NAME not in cfg:
+        if default_args is None or TYPE_NAME not in default_args:
             raise KeyError(
-                '`cfg` or `default_args` must contain the key "type", '
+                f'`cfg` or `default_args` must contain the key "{TYPE_NAME}", '
                 f'but got {cfg}\n{default_args}')
     if not isinstance(registry, Registry):
         raise TypeError('registry must be an modelscope.Registry object, '
@@ -184,7 +186,7 @@
     if group_key is None:
         group_key = default_group

-    obj_type = args.pop('type')
+    obj_type = args.pop(TYPE_NAME)
     if isinstance(obj_type, str):
         obj_cls = registry.get(obj_type, group_key=group_key)
         if obj_cls is None:
@@ -196,7 +198,10 @@
         raise TypeError(
             f'type must be a str or valid type, but got {type(obj_type)}')
     try:
-        return obj_cls(**args)
+        if hasattr(obj_cls, '_instantiate'):
+            return obj_cls._instantiate(**args)
+        else:
+            return obj_cls(**args)
     except Exception as e:
         # Normal TypeError does not print class name.
         raise type(e)(f'{obj_cls.__name__}: {e}')

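The practical effect of the registry change above: when a registered class defines a classmethod _instantiate (as BaseTaskModel in task_model.py does), build_from_cfg now routes construction through it, so a task model can load its checkpoint while being built. A hypothetical registered class only needs something like:

    from modelscope.models.base import TorchModel

    class ToyModel(TorchModel):

        @classmethod
        def _instantiate(cls, **kwargs):
            # called by build_from_cfg instead of cls(**kwargs); a good place to
            # construct the model and then load weights from kwargs['model_dir']
            return cls(**kwargs)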
+10 -6   modelscope/utils/tensor_utils.py

@@ -2,6 +2,8 @@
 # Part of the implementation is borrowed from huggingface/transformers.
 from collections.abc import Mapping

+import numpy as np
+

 def torch_nested_numpify(tensors):
     import torch
@@ -27,9 +29,6 @@
 def torch_default_data_collator(features):
     # TODO @jiangnana.jnn refine this default data collator
     import torch
-
-    # if not isinstance(features[0], (dict, BatchEncoding)):
-    #     features = [vars(f) for f in features]
     first = features[0]

     if isinstance(first, Mapping):
@@ -40,9 +39,14 @@
         if 'label' in first and first['label'] is not None:
             label = first['label'].item() if isinstance(
                 first['label'], torch.Tensor) else first['label']
-            dtype = torch.long if isinstance(label, int) else torch.float
-            batch['labels'] = torch.tensor([f['label'] for f in features],
-                                           dtype=dtype)
+            # MsDataset may return a 0-dimension np.array holding a single value; handle that here.
+            if isinstance(label, np.ndarray):
+                dtype = torch.long if label[(
+                )].dtype == np.int64 else torch.float
+            else:
+                dtype = torch.long if isinstance(label, int) else torch.float
+            batch['labels'] = torch.tensor(
+                np.array([f['label'] for f in features]), dtype=dtype)
         elif 'label_ids' in first and first['label_ids'] is not None:
             if isinstance(first['label_ids'], torch.Tensor):
                 batch['labels'] = torch.stack(


+28 -0   modelscope/utils/utils.py

@@ -0,0 +1,28 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import inspect


def if_func_recieve_dict_inputs(func, inputs):
"""to decide if a func could recieve dict inputs or not

Args:
func (class): the target function to be inspected
inputs (dicts): the inputs that will send to the function

Returns:
bool: if func recieve dict, then recieve True

Examples:
input = {"input_dict":xxx, "attention_masked":xxx},
function(self, inputs) then return True
function(inputs) then return True
function(self, input_dict, attention_masked) then return False
"""
signature = inspect.signature(func)
func_inputs = list(signature.parameters.keys() - set(['self']))
mismatched_inputs = list(set(func_inputs) - set(inputs))
if len(func_inputs) == len(mismatched_inputs):
return True
else:
return False

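A quick illustration of the helper above (the function names are arbitrary): a forward that takes one non-matching parameter is fed the whole dict, while a forward whose parameters match the input keys gets keyword arguments.

    from modelscope.utils.utils import if_func_recieve_dict_inputs

    def forward_dict(self, inputs):
        ...

    def forward_kwargs(self, input_ids, attention_mask):
        ...

    if_func_recieve_dict_inputs(forward_dict, {'input_ids': 1, 'attention_mask': 2})    # True
    if_func_recieve_dict_inputs(forward_kwargs, {'input_ids': 1, 'attention_mask': 2})  # False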
+1 -1   tests/models/test_base_torch.py

@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from modelscope.models.base_torch import TorchModel
+from modelscope.models.base import TorchModel


 class TorchBaseTest(unittest.TestCase):


+16 -10   tests/pipelines/test_sentiment_classification.py

@@ -3,7 +3,8 @@ import unittest

 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models import Model
-from modelscope.models.nlp import SbertForSentimentClassification
+from modelscope.models.nlp import (SbertForSentimentClassification,
+                                   SequenceClassificationModel)
 from modelscope.pipelines import SentimentClassificationPipeline, pipeline
 from modelscope.preprocessors import SentimentClassificationPreprocessor
 from modelscope.utils.constant import Tasks
@@ -18,39 +19,44 @@ class SentimentClassificationTest(unittest.TestCase):
     def test_run_with_direct_file_download(self):
         cache_path = snapshot_download(self.model_id)
         tokenizer = SentimentClassificationPreprocessor(cache_path)
-        model = SbertForSentimentClassification(
-            cache_path, tokenizer=tokenizer)
+        model = SequenceClassificationModel.from_pretrained(
+            self.model_id, num_labels=2)
         pipeline1 = SentimentClassificationPipeline(
             model, preprocessor=tokenizer)
         pipeline2 = pipeline(
             Tasks.sentiment_classification,
             model=model,
-            preprocessor=tokenizer)
+            preprocessor=tokenizer,
+            model_revision='beta')
         print(f'sentence1: {self.sentence1}\n'
               f'pipeline1:{pipeline1(input=self.sentence1)}')
         print()
         print(f'sentence1: {self.sentence1}\n'
               f'pipeline1: {pipeline2(input=self.sentence1)}')

-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
         model = Model.from_pretrained(self.model_id)
         tokenizer = SentimentClassificationPreprocessor(model.model_dir)
         pipeline_ins = pipeline(
             task=Tasks.sentiment_classification,
             model=model,
-            preprocessor=tokenizer)
+            preprocessor=tokenizer,
+            model_revision='beta')
         print(pipeline_ins(input=self.sentence1))

-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_model_name(self):
         pipeline_ins = pipeline(
-            task=Tasks.sentiment_classification, model=self.model_id)
+            task=Tasks.sentiment_classification,
+            model=self.model_id,
+            model_revision='beta')
         print(pipeline_ins(input=self.sentence1))

-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_default_model(self):
-        pipeline_ins = pipeline(task=Tasks.sentiment_classification)
+        pipeline_ins = pipeline(
+            task=Tasks.sentiment_classification, model_revision='beta')
         print(pipeline_ins(input=self.sentence1))






+17 -0   tests/trainers/test_trainer_with_nlp.py

@@ -56,6 +56,23 @@ class TestTrainerWithNlp(unittest.TestCase):
         for i in range(10):
             self.assertIn(f'epoch_{i+1}.pth', results_files)

+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_trainer_with_backbone_head(self):
+        model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base'
+        kwargs = dict(
+            model=model_id,
+            train_dataset=self.dataset,
+            eval_dataset=self.dataset,
+            work_dir=self.tmp_dir,
+            model_revision='beta')
+
+        trainer = build_trainer(default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(10):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_trainer_with_model_and_args(self):
         tmp_dir = tempfile.TemporaryDirectory().name

