@@ -6,25 +6,34 @@ | |||
"second_sequence": "sentence2" | |||
}, | |||
"model": { | |||
"type": "structbert", | |||
"attention_probs_dropout_prob": 0.1, | |||
"easynlp_version": "0.0.3", | |||
"gradient_checkpointing": false, | |||
"hidden_act": "gelu", | |||
"hidden_dropout_prob": 0.1, | |||
"hidden_size": 768, | |||
"initializer_range": 0.02, | |||
"intermediate_size": 3072, | |||
"layer_norm_eps": 1e-12, | |||
"max_position_embeddings": 512, | |||
"num_attention_heads": 12, | |||
"num_hidden_layers": 12, | |||
"pad_token_id": 0, | |||
"position_embedding_type": "absolute", | |||
"transformers_version": "4.6.0.dev0", | |||
"type_vocab_size": 2, | |||
"use_cache": true, | |||
"vocab_size": 30522 | |||
"type": "text-classification", | |||
"backbone": { | |||
"type": "structbert", | |||
"prefix": "encoder", | |||
"attention_probs_dropout_prob": 0.1, | |||
"easynlp_version": "0.0.3", | |||
"gradient_checkpointing": false, | |||
"hidden_act": "gelu", | |||
"hidden_dropout_prob": 0.1, | |||
"hidden_size": 768, | |||
"initializer_range": 0.02, | |||
"intermediate_size": 3072, | |||
"layer_norm_eps": 1e-12, | |||
"max_position_embeddings": 512, | |||
"num_attention_heads": 12, | |||
"num_hidden_layers": 12, | |||
"pad_token_id": 0, | |||
"position_embedding_type": "absolute", | |||
"transformers_version": "4.6.0.dev0", | |||
"type_vocab_size": 2, | |||
"use_cache": true, | |||
"vocab_size": 21128 | |||
}, | |||
"head": { | |||
"type": "text-classification", | |||
"hidden_dropout_prob": 0.1, | |||
"hidden_size": 768 | |||
} | |||
}, | |||
"pipeline": { | |||
"type": "sentence-similarity" | |||
@@ -6,6 +6,9 @@ task: text-classification | |||
model: | |||
path: bert-base-sst2 | |||
backbone: | |||
type: bert | |||
prefix: bert | |||
attention_probs_dropout_prob: 0.1 | |||
bos_token_id: 0 | |||
eos_token_id: 2 | |||
@@ -33,6 +33,16 @@ class Models(object): | |||
imagen = 'imagen-text-to-image-synthesis' | |||
class TaskModels(object): | |||
# nlp task | |||
text_classification = 'text-classification' | |||
class Heads(object): | |||
# nlp heads | |||
text_classification = 'text-classification' | |||
class Pipelines(object): | |||
""" Names for different pipelines. | |||
@@ -17,6 +17,7 @@ class MetricKeys(object): | |||
task_default_metrics = { | |||
Tasks.sentence_similarity: [Metrics.seq_cls_metric], | |||
Tasks.sentiment_classification: [Metrics.seq_cls_metric], | |||
Tasks.text_generation: [Metrics.text_gen_metric], | |||
} | |||
@@ -29,6 +29,8 @@ try: | |||
SbertForZeroShotClassification, SpaceForDialogIntent, | |||
SpaceForDialogModeling, SpaceForDialogStateTracking, | |||
StructBertForMaskedLM, VecoForMaskedLM) | |||
from .nlp.heads import (SequenceClassificationHead) | |||
from .nlp.backbones import (SbertModel) | |||
except ModuleNotFoundError as e: | |||
if str(e) == "No module named 'pytorch'": | |||
pass | |||
@@ -0,0 +1,4 @@ | |||
from .base_head import * # noqa F403 | |||
from .base_model import * # noqa F403 | |||
from .base_torch_head import * # noqa F403 | |||
from .base_torch_model import * # noqa F403 |
@@ -0,0 +1,48 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from abc import ABC, abstractmethod | |||
from typing import Dict, List, Union | |||
import numpy as np | |||
from ...utils.config import ConfigDict | |||
from ...utils.logger import get_logger | |||
from .base_model import Model | |||
logger = get_logger() | |||
Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||
Input = Union[Dict[str, Tensor], Model] | |||
class Head(ABC): | |||
""" | |||
The base class for task heads, defining the interface a downstream task head must implement | |||
""" | |||
def __init__(self, **kwargs): | |||
self.config = ConfigDict(kwargs) | |||
@abstractmethod | |||
def forward(self, input: Input) -> Dict[str, Tensor]: | |||
""" | |||
Use the output from the backbone model to perform a downstream | |||
task. | |||
Args: | |||
input: The tensor output of the backbone model, or the backbone | |||
model itself (text generation needs a model as input) | |||
Returns: The output of the downstream task | |||
""" | |||
pass | |||
@abstractmethod | |||
def compute_loss(self, outputs: Dict[str, Tensor], | |||
labels) -> Dict[str, Tensor]: | |||
""" | |||
Compute the loss for this head during finetuning. | |||
Args: | |||
outputs (Dict[str, Tensor]): the output from the model forward pass | |||
labels: the ground-truth labels | |||
Returns: Dict[str, Tensor]: the computed loss | |||
""" | |||
pass |
@@ -4,6 +4,8 @@ import os.path as osp | |||
from abc import ABC, abstractmethod | |||
from typing import Dict, Optional, Union | |||
import numpy as np | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.models.builder import build_model | |||
from modelscope.utils.config import Config | |||
@@ -25,6 +27,15 @@ class Model(ABC): | |||
@abstractmethod | |||
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
""" | |||
Run the forward pass for a model. | |||
Args: | |||
input (Dict[str, Tensor]): the dict of the model inputs for the forward method | |||
Returns: | |||
Dict[str, Tensor]: output from the model forward pass | |||
""" | |||
pass | |||
def postprocess(self, input: Dict[str, Tensor], | |||
@@ -41,6 +52,15 @@ class Model(ABC): | |||
""" | |||
return input | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
""" Define the instantiation method of a model,default method is by | |||
calling the constructor. Note that in the case of no loading model | |||
process in constructor of a task model, a load_model method is | |||
added, and thus this method is overloaded | |||
""" | |||
return cls(**kwargs) | |||
@classmethod | |||
def from_pretrained(cls, | |||
model_name_or_path: str, | |||
@@ -71,6 +91,7 @@ class Model(ABC): | |||
cfg, 'pipeline'), 'pipeline config is missing from config file.' | |||
pipeline_cfg = cfg.pipeline | |||
# TODO @wenmeng.zwm we may need to manually initialize the model after model building | |||
if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): | |||
model_cfg.type = model_cfg.model_type | |||
@@ -78,7 +99,8 @@ class Model(ABC): | |||
for k, v in kwargs.items(): | |||
model_cfg[k] = v | |||
model = build_model(model_cfg, task_name) | |||
model = build_model( | |||
model_cfg, task_name=task_name, default_args=kwargs) | |||
# dynamically add pipeline info to model for pipeline inference | |||
model.pipeline = pipeline_cfg |
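For orientation, a hedged sketch of the loading path these changes affect: kwargs passed to from_pretrained are copied into the model config and forwarded to build_model as default_args, and a task model that cannot be constructed by __init__ alone can override _instantiate. The import path, model id and extra kwarg below are placeholders, not something introduced by this change.

# Hypothetical usage sketch; import path assumed from the new models/base package.
from modelscope.models.base import Model

model = Model.from_pretrained('your-namespace/your-text-classification-model',
                              num_labels=2)  # extra kwargs land in the model config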
@@ -0,0 +1,30 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os.path | |||
import re | |||
from typing import Dict, Optional, Union | |||
import torch | |||
from torch import nn | |||
from ...utils.logger import get_logger | |||
from .base_head import Head | |||
logger = get_logger(__name__) | |||
class TorchHead(Head, torch.nn.Module): | |||
""" Base head interface for pytorch | |||
""" | |||
def __init__(self, **kwargs): | |||
super().__init__(**kwargs) | |||
torch.nn.Module.__init__(self) | |||
def forward(self, inputs: Dict[str, | |||
torch.Tensor]) -> Dict[str, torch.Tensor]: | |||
raise NotImplementedError | |||
def compute_loss(self, outputs: Dict[str, torch.Tensor], | |||
labels) -> Dict[str, torch.Tensor]: | |||
raise NotImplementedError |
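A minimal sketch of a concrete head built on this interface; the class name and config fields are illustrative, and the SequenceClassificationHead added later in this change follows the same pattern.

from torch import nn


class ToyClassificationHead(TorchHead):
    """Illustrative head: a linear classifier over the pooled backbone output."""

    def __init__(self, **kwargs):
        # hidden_size and num_labels are assumed to be supplied via kwargs;
        # Head.__init__ collects them into self.config (a ConfigDict).
        super().__init__(**kwargs)
        self.classifier = nn.Linear(self.config.hidden_size,
                                    self.config.num_labels)

    def forward(self, inputs):
        return {'logits': self.classifier(inputs['pooler_output'])}

    def compute_loss(self, outputs, labels):
        return {'loss': nn.CrossEntropyLoss()(outputs['logits'], labels)}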
@@ -0,0 +1,55 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict, Optional, Union | |||
import torch | |||
from torch import nn | |||
from ...utils.logger import get_logger | |||
from .base_model import Model | |||
logger = get_logger(__name__) | |||
class TorchModel(Model, torch.nn.Module): | |||
""" Base model interface for pytorch | |||
""" | |||
def __init__(self, model_dir=None, *args, **kwargs): | |||
super().__init__(model_dir, *args, **kwargs) | |||
torch.nn.Module.__init__(self) | |||
def forward(self, inputs: Dict[str, | |||
torch.Tensor]) -> Dict[str, torch.Tensor]: | |||
raise NotImplementedError | |||
def post_init(self): | |||
""" | |||
A method executed at the end of each model initialization, to execute code that needs the model's | |||
modules properly initialized (such as weight initialization). | |||
""" | |||
self.init_weights() | |||
def init_weights(self): | |||
# Initialize weights | |||
self.apply(self._init_weights) | |||
def _init_weights(self, module): | |||
"""Initialize the weights""" | |||
if isinstance(module, nn.Linear): | |||
# Slightly different from the TF version which uses truncated_normal for initialization | |||
# cf https://github.com/pytorch/pytorch/pull/5617 | |||
module.weight.data.normal_(mean=0.0, std=0.02) | |||
if module.bias is not None: | |||
module.bias.data.zero_() | |||
elif isinstance(module, nn.Embedding): | |||
module.weight.data.normal_(mean=0.0, std=0.02) | |||
if module.padding_idx is not None: | |||
module.weight.data[module.padding_idx].zero_() | |||
elif isinstance(module, nn.LayerNorm): | |||
module.bias.data.zero_() | |||
module.weight.data.fill_(1.0) | |||
def compute_loss(self, outputs: Dict[str, Any], labels): | |||
raise NotImplementedError() |
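As a small illustration of the weight-initialization hook (the subclass below is hypothetical), a model is expected to call post_init once its submodules have been created:

from torch import nn


class ToyModel(TorchModel):
    """Illustrative subclass, not part of this change."""

    def __init__(self, model_dir=None, hidden_size=16, **kwargs):
        super().__init__(model_dir, **kwargs)
        self.linear = nn.Linear(hidden_size, hidden_size)
        # Applies _init_weights to every submodule: normal init for Linear
        # and Embedding weights, ones/zeros for LayerNorm.
        self.post_init()

    def forward(self, inputs):
        return {'outputs': self.linear(inputs['inputs'])}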
@@ -1,23 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Dict | |||
import torch | |||
from .base import Model | |||
class TorchModel(Model, torch.nn.Module): | |||
""" Base model interface for pytorch | |||
""" | |||
def __init__(self, model_dir=None, *args, **kwargs): | |||
# init reference: https://stackoverflow.com/questions\ | |||
# /9575409/calling-parent-class-init-with-multiple-inheritance-whats-the-right-way | |||
super().__init__(model_dir) | |||
super(Model, self).__init__() | |||
def forward(self, inputs: Dict[str, | |||
torch.Tensor]) -> Dict[str, torch.Tensor]: | |||
raise NotImplementedError |
@@ -1,9 +1,11 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from modelscope.utils.config import ConfigDict | |||
from modelscope.utils.registry import Registry, build_from_cfg | |||
from modelscope.utils.registry import TYPE_NAME, Registry, build_from_cfg | |||
MODELS = Registry('models') | |||
BACKBONES = Registry('backbones') | |||
HEADS = Registry('heads') | |||
def build_model(cfg: ConfigDict, | |||
@@ -19,3 +21,29 @@ def build_model(cfg: ConfigDict, | |||
""" | |||
return build_from_cfg( | |||
cfg, MODELS, group_key=task_name, default_args=default_args) | |||
def build_backbone(cfg: ConfigDict, | |||
field: str = None, | |||
default_args: dict = None): | |||
""" build backbone given backbone config dict | |||
Args: | |||
cfg (:obj:`ConfigDict`): config dict for backbone object. | |||
field (str, optional): the field this backbone belongs to, e.g. a CV or NLP backbone | |||
default_args (dict, optional): Default initialization arguments. | |||
""" | |||
return build_from_cfg( | |||
cfg, BACKBONES, group_key=field, default_args=default_args) | |||
def build_head(cfg: ConfigDict, default_args: dict = None): | |||
""" build head given config dict | |||
Args: | |||
cfg (:obj:`ConfigDict`): config dict for head object. | |||
default_args (dict, optional): Default initialization arguments. | |||
""" | |||
return build_from_cfg( | |||
cfg, HEADS, group_key=cfg[TYPE_NAME], default_args=default_args) |
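A hedged usage sketch of the new registries; the config fields mirror the backbone/head sections of the configuration file at the top of this change, and whether a concrete head needs extra fields (e.g. num_labels) depends on its implementation.

from modelscope.models.builder import build_backbone, build_head
from modelscope.utils.config import ConfigDict
from modelscope.utils.constant import Fields

# Backbones are grouped by field (e.g. NLP); heads are grouped by task,
# with the head type doubling as the lookup key.
backbone = build_backbone(ConfigDict({'type': 'structbert'}), field=Fields.nlp)

head = build_head(
    ConfigDict({
        'type': 'text-classification',
        'hidden_size': 768,
        'hidden_dropout_prob': 0.1,
        'num_labels': 2
    }))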
@@ -1,6 +1,8 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from modelscope.utils.error import TENSORFLOW_IMPORT_WARNING | |||
from ...utils.error import TENSORFLOW_IMPORT_WARNING | |||
from .backbones import * # noqa F403 | |||
from .bert_for_sequence_classification import * # noqa F403 | |||
from .heads import * # noqa F403 | |||
from .masked_language import * # noqa F403 | |||
from .nncrf_for_named_entity_recognition import * # noqa F403 | |||
from .palm_for_text_generation import * # noqa F403 | |||
@@ -9,9 +11,10 @@ from .sbert_for_sentence_similarity import * # noqa F403 | |||
from .sbert_for_sentiment_classification import * # noqa F403 | |||
from .sbert_for_token_classification import * # noqa F403 | |||
from .sbert_for_zero_shot_classification import * # noqa F403 | |||
from .space.dialog_intent_prediction_model import * # noqa F403 | |||
from .space.dialog_modeling_model import * # noqa F403 | |||
from .space.dialog_state_tracking_model import * # noqa F403 | |||
from .sequence_classification import * # noqa F403 | |||
from .space_for_dialog_intent_prediction import * # noqa F403 | |||
from .space_for_dialog_modeling import * # noqa F403 | |||
from .space_for_dialog_state_tracking import * # noqa F403 | |||
try: | |||
from .csanmt_for_translation import CsanmtForTranslation | |||
@@ -0,0 +1,4 @@ | |||
from .space import SpaceGenerator, SpaceModelBase | |||
from .structbert import SbertModel | |||
__all__ = ['SbertModel', 'SpaceGenerator', 'SpaceModelBase'] |
@@ -0,0 +1,2 @@ | |||
from .model.generator import Generator as SpaceGenerator | |||
from .model.model_base import SpaceModelBase |
@@ -4,7 +4,7 @@ import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from .....utils.nlp.space.criterions import compute_kl_loss | |||
from ......utils.nlp.space.criterions import compute_kl_loss | |||
from .unified_transformer import UnifiedTransformer | |||
@@ -4,7 +4,7 @@ import os | |||
import torch.nn as nn | |||
from .....utils.constant import ModelFile | |||
from ......utils.constant import ModelFile | |||
class SpaceModelBase(nn.Module): |
@@ -0,0 +1 @@ | |||
from .modeling_sbert import SbertModel |
@@ -0,0 +1,166 @@ | |||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
# All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import torch | |||
from torch import nn | |||
from .....utils.logger import get_logger | |||
logger = get_logger(__name__) | |||
def _symmetric_kl_div(logits1, logits2, attention_mask=None): | |||
""" | |||
Calculate the symmetric KL divergence between two logits tensors. | |||
:param logits1: The first logits tensor. | |||
:param logits2: The second logits tensor. | |||
:param attention_mask: An optional attention_mask used to mask some elements out. | |||
This is usually useful in token_classification tasks. | |||
If the shape of logits is [N1, N2, ... Nn, D], the shape of attention_mask should be [N1, N2, ... Nn] | |||
:return: The mean loss. | |||
""" | |||
labels_num = logits1.shape[-1] | |||
KLDiv = nn.KLDivLoss(reduction='none') | |||
loss = torch.sum( | |||
KLDiv(nn.LogSoftmax(dim=-1)(logits1), | |||
nn.Softmax(dim=-1)(logits2)), | |||
dim=-1) + torch.sum( | |||
KLDiv(nn.LogSoftmax(dim=-1)(logits2), | |||
nn.Softmax(dim=-1)(logits1)), | |||
dim=-1) | |||
if attention_mask is not None: | |||
loss = torch.sum( | |||
loss * attention_mask) / torch.sum(attention_mask) / labels_num | |||
else: | |||
loss = torch.mean(loss) / labels_num | |||
return loss | |||
def compute_adv_loss(embedding, | |||
model, | |||
ori_logits, | |||
ori_loss, | |||
adv_grad_factor, | |||
adv_bound=None, | |||
sigma=5e-6, | |||
**kwargs): | |||
""" | |||
Calculate the adversarial loss of the model. | |||
:param embedding: The original sentence embedding. | |||
:param model: The model or forward function (including decoder/classifier); it accepts kwargs as input and outputs logits. | |||
:param ori_logits: The original logits output by the model function. | |||
:param ori_loss: The original loss. | |||
:param adv_grad_factor: This factor is multiplied by the gradient of the KL loss and the result is added to | |||
the original embedding. | |||
For more details see: https://arxiv.org/abs/1908.04577 | |||
This value typically lies in the range 1e-7 ~ 1e-3. | |||
:param adv_bound: adv_bound clips the produced embedding from above and below. | |||
If not provided, 2 * sigma is used as the adv_bound factor. | |||
:param sigma: The std used to produce a zero-mean normal distribution. | |||
If adv_bound is not provided, 2 * sigma is used as the adv_bound factor. | |||
:param kwargs: The input params passed to the model function. | |||
:return: The original loss plus the adversarial loss. | |||
""" | |||
adv_bound = adv_bound if adv_bound is not None else 2 * sigma | |||
embedding_1 = embedding + embedding.data.new(embedding.size()).normal_( | |||
0, sigma) # 95% in +- 1e-5 | |||
kwargs.pop('input_ids') | |||
if 'inputs_embeds' in kwargs: | |||
kwargs.pop('inputs_embeds') | |||
with_attention_mask = False if 'with_attention_mask' not in kwargs else kwargs[ | |||
'with_attention_mask'] | |||
attention_mask = kwargs['attention_mask'] | |||
if not with_attention_mask: | |||
attention_mask = None | |||
if 'with_attention_mask' in kwargs: | |||
kwargs.pop('with_attention_mask') | |||
outputs = model(**kwargs, inputs_embeds=embedding_1) | |||
v1_logits = outputs.logits | |||
loss = _symmetric_kl_div(ori_logits, v1_logits, attention_mask) | |||
emb_grad = torch.autograd.grad(loss, embedding_1)[0].data | |||
emb_grad_norm = emb_grad.norm( | |||
dim=2, keepdim=True, p=float('inf')).max( | |||
1, keepdim=True)[0] | |||
is_nan = torch.any(torch.isnan(emb_grad_norm)) | |||
if is_nan: | |||
logger.warning('NaN occurred when calculating adv loss.') | |||
return ori_loss | |||
emb_grad = emb_grad / emb_grad_norm | |||
embedding_2 = embedding_1 + adv_grad_factor * emb_grad | |||
embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2) | |||
embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2) | |||
outputs = model(**kwargs, inputs_embeds=embedding_2) | |||
adv_logits = outputs.logits | |||
adv_loss = _symmetric_kl_div(ori_logits, adv_logits, attention_mask) | |||
return ori_loss + adv_loss | |||
def compute_adv_loss_pair(embedding, | |||
model, | |||
start_logits, | |||
end_logits, | |||
ori_loss, | |||
adv_grad_factor, | |||
adv_bound=None, | |||
sigma=5e-6, | |||
**kwargs): | |||
""" | |||
Calculate the adversarial loss of the model. This function is used in the paired-logits scenario. | |||
:param embedding: The original sentence embedding. | |||
:param model: The model or forward function (including decoder/classifier); it accepts kwargs as input and outputs logits. | |||
:param start_logits: The original start logits output by the model function. | |||
:param end_logits: The original end logits output by the model function. | |||
:param ori_loss: The original loss. | |||
:param adv_grad_factor: This factor is multiplied by the gradient of the KL loss and the result is added to | |||
the original embedding. | |||
For more details see: https://arxiv.org/abs/1908.04577 | |||
This value typically lies in the range 1e-7 ~ 1e-3. | |||
:param adv_bound: adv_bound clips the produced embedding from above and below. | |||
If not provided, 2 * sigma is used as the adv_bound factor. | |||
:param sigma: The std used to produce a zero-mean normal distribution. | |||
If adv_bound is not provided, 2 * sigma is used as the adv_bound factor. | |||
:param kwargs: The input params passed to the model function. | |||
:return: The original loss plus the adversarial loss. | |||
""" | |||
adv_bound = adv_bound if adv_bound is not None else 2 * sigma | |||
embedding_1 = embedding + embedding.data.new(embedding.size()).normal_( | |||
0, sigma) # 95% in +- 1e-5 | |||
kwargs.pop('input_ids') | |||
if 'inputs_embeds' in kwargs: | |||
kwargs.pop('inputs_embeds') | |||
outputs = model(**kwargs, inputs_embeds=embedding_1) | |||
v1_logits_start, v1_logits_end = outputs.logits | |||
loss = _symmetric_kl_div(start_logits, | |||
v1_logits_start) + _symmetric_kl_div( | |||
end_logits, v1_logits_end) | |||
loss = loss / 2 | |||
emb_grad = torch.autograd.grad(loss, embedding_1)[0].data | |||
emb_grad_norm = emb_grad.norm( | |||
dim=2, keepdim=True, p=float('inf')).max( | |||
1, keepdim=True)[0] | |||
is_nan = torch.any(torch.isnan(emb_grad_norm)) | |||
if is_nan: | |||
logger.warning('NaN occurred when calculating pair adv loss.') | |||
return ori_loss | |||
emb_grad = emb_grad / emb_grad_norm | |||
embedding_2 = embedding_1 + adv_grad_factor * emb_grad | |||
embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2) | |||
embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2) | |||
outputs = model(**kwargs, inputs_embeds=embedding_2) | |||
adv_logits_start, adv_logits_end = outputs.logits | |||
adv_loss = _symmetric_kl_div(start_logits, | |||
adv_logits_start) + _symmetric_kl_div( | |||
end_logits, adv_logits_end) | |||
return ori_loss + adv_loss |
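Because the calling convention is easy to get wrong (input_ids and attention_mask must be passed through kwargs, and the callable must accept inputs_embeds and return an object with a .logits attribute), here is a self-contained toy sketch of a call site. Everything named below is illustrative, and compute_adv_loss is assumed to be importable from this module.

import torch
from torch import nn


class ToyOutput:

    def __init__(self, logits):
        self.logits = logits


class ToyClassifier(nn.Module):
    """Toy stand-in for a backbone + classifier that accepts inputs_embeds."""

    def __init__(self, vocab_size=100, dim=16, num_labels=2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.fc = nn.Linear(dim, num_labels)

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        x = self.emb(input_ids) if inputs_embeds is None else inputs_embeds
        return ToyOutput(self.fc(x.mean(dim=1)))


model = ToyClassifier()
input_ids = torch.randint(0, 100, (2, 5))
attention_mask = torch.ones(2, 5)
labels = torch.tensor([0, 1])

embeds = model.emb(input_ids)  # original input embeddings
ori_logits = model(inputs_embeds=embeds, attention_mask=attention_mask).logits
ori_loss = nn.CrossEntropyLoss()(ori_logits, labels)

# compute_adv_loss pops input_ids from kwargs and re-runs the model on a
# perturbed copy of the embeddings via inputs_embeds.
total_loss = compute_adv_loss(
    embedding=embeds,
    model=model,
    ori_logits=ori_logits,
    ori_loss=ori_loss,
    adv_grad_factor=5e-5,
    input_ids=input_ids,
    attention_mask=attention_mask)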
@@ -0,0 +1,131 @@ | |||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||
# All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" SBERT model configuration, mainly copied from :class:`~transformers.BertConfig` """ | |||
from transformers import PretrainedConfig | |||
from .....utils import logger as logging | |||
logger = logging.get_logger(__name__) | |||
class SbertConfig(PretrainedConfig): | |||
r""" | |||
This is the configuration class to store the configuration of a :class:`~sofa.models.SbertModel`. | |||
It is used to instantiate a SBERT model according to the specified arguments. | |||
Configuration objects inherit from :class:`~sofa.utils.PretrainedConfig` and can be used to control the model | |||
outputs. Read the documentation from :class:`~sofa.utils.PretrainedConfig` for more information. | |||
Args: | |||
vocab_size (:obj:`int`, `optional`, defaults to 30522): | |||
Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the | |||
:obj:`inputs_ids` passed when calling :class:`~transformers.BertModel` or | |||
:class:`~transformers.TFBertModel`. | |||
hidden_size (:obj:`int`, `optional`, defaults to 768): | |||
Dimensionality of the encoder layers and the pooler layer. | |||
num_hidden_layers (:obj:`int`, `optional`, defaults to 12): | |||
Number of hidden layers in the Transformer encoder. | |||
num_attention_heads (:obj:`int`, `optional`, defaults to 12): | |||
Number of attention heads for each attention layer in the Transformer encoder. | |||
intermediate_size (:obj:`int`, `optional`, defaults to 3072): | |||
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. | |||
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`): | |||
The non-linear activation function (function or string) in the encoder and pooler. If string, | |||
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported. | |||
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): | |||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. | |||
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): | |||
The dropout ratio for the attention probabilities. | |||
max_position_embeddings (:obj:`int`, `optional`, defaults to 512): | |||
The maximum sequence length that this model might ever be used with. Typically set this to something large | |||
just in case (e.g., 512 or 1024 or 2048). | |||
type_vocab_size (:obj:`int`, `optional`, defaults to 2): | |||
The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BertModel` or | |||
:class:`~transformers.TFBertModel`. | |||
initializer_range (:obj:`float`, `optional`, defaults to 0.02): | |||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices. | |||
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): | |||
The epsilon used by the layer normalization layers. | |||
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): | |||
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, | |||
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on | |||
:obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.) | |||
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to | |||
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.) | |||
<https://arxiv.org/abs/2009.13658>`__. | |||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): | |||
Whether or not the model should return the last key/values attentions (not used by all models). Only | |||
relevant if ``config.is_decoder=True``. | |||
classifier_dropout (:obj:`float`, `optional`): | |||
The dropout ratio for the classification head. | |||
adv_grad_factor (:obj:`float`, `optional`): This factor is multiplied by the gradient of the KL loss and | |||
the result is added to the original embedding. | |||
For more details see: https://arxiv.org/abs/1908.04577 | |||
This value typically lies in the range 1e-7 ~ 1e-3. | |||
adv_bound (:obj:`float`, `optional`): adv_bound clips the produced embedding from | |||
above and below. | |||
If not provided, 2 * sigma is used as the adv_bound factor. | |||
sigma (:obj:`float`, `optional`): The std used to produce a zero-mean normal distribution. | |||
If adv_bound is not provided, 2 * sigma is used as the adv_bound factor. | |||
""" | |||
model_type = 'sbert' | |||
def __init__(self, | |||
vocab_size=30522, | |||
hidden_size=768, | |||
num_hidden_layers=12, | |||
num_attention_heads=12, | |||
intermediate_size=3072, | |||
hidden_act='gelu', | |||
hidden_dropout_prob=0.1, | |||
attention_probs_dropout_prob=0.1, | |||
max_position_embeddings=512, | |||
type_vocab_size=2, | |||
initializer_range=0.02, | |||
layer_norm_eps=1e-12, | |||
position_embedding_type='absolute', | |||
use_cache=True, | |||
classifier_dropout=None, | |||
**kwargs): | |||
super().__init__(**kwargs) | |||
self.vocab_size = vocab_size | |||
self.hidden_size = hidden_size | |||
self.num_hidden_layers = num_hidden_layers | |||
self.num_attention_heads = num_attention_heads | |||
self.hidden_act = hidden_act | |||
self.intermediate_size = intermediate_size | |||
self.hidden_dropout_prob = hidden_dropout_prob | |||
self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||
self.max_position_embeddings = max_position_embeddings | |||
self.type_vocab_size = type_vocab_size | |||
self.initializer_range = initializer_range | |||
self.layer_norm_eps = layer_norm_eps | |||
self.position_embedding_type = position_embedding_type | |||
self.use_cache = use_cache | |||
self.classifier_dropout = classifier_dropout | |||
# adv_grad_factor, used in adv loss. | |||
# Users can check adv_utils.py for details. | |||
# If adv_grad_factor is set to None, no adv loss is applied to the model. | |||
self.adv_grad_factor = 5e-5 if 'adv_grad_factor' not in kwargs else kwargs[ | |||
'adv_grad_factor'] | |||
# sigma value, used in adv loss. | |||
self.sigma = 5e-6 if 'sigma' not in kwargs else kwargs['sigma'] | |||
# adv_bound value, used in adv loss. | |||
self.adv_bound = 2 * self.sigma if 'adv_bound' not in kwargs else kwargs[ | |||
'adv_bound'] |
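A quick illustration of how the adversarial-training fields resolve; the values shown are just the defaults documented above.

config = SbertConfig(vocab_size=21128)
assert config.adv_grad_factor == 5e-5
assert config.sigma == 5e-6
assert config.adv_bound == 2 * config.sigma

# Passing the fields explicitly overrides the defaults.
config = SbertConfig(vocab_size=21128, adv_grad_factor=1e-5, sigma=1e-6)
assert config.adv_bound == 2e-6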
@@ -0,0 +1,815 @@ | |||
import math | |||
from dataclasses import dataclass | |||
from typing import Optional, Tuple, Union | |||
import torch | |||
import torch.utils.checkpoint | |||
from packaging import version | |||
from torch import nn | |||
from transformers import PreTrainedModel | |||
from transformers.activations import ACT2FN | |||
from transformers.modeling_outputs import ( | |||
BaseModelOutputWithPastAndCrossAttentions, | |||
BaseModelOutputWithPoolingAndCrossAttentions, ModelOutput) | |||
from transformers.modeling_utils import (apply_chunking_to_forward, | |||
find_pruneable_heads_and_indices, | |||
prune_linear_layer) | |||
from .....metainfo import Models | |||
from .....utils.constant import Fields | |||
from .....utils.logger import get_logger | |||
from ....base import TorchModel | |||
from ....builder import BACKBONES | |||
from .configuration_sbert import SbertConfig | |||
logger = get_logger(__name__) | |||
@BACKBONES.register_module(Fields.nlp, module_name=Models.structbert) | |||
class SbertModel(TorchModel, PreTrainedModel): | |||
""" | |||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of | |||
cross-attention is added between the self-attention layers, following the architecture described in `Attention is | |||
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, | |||
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. | |||
To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration | |||
set to :obj:`True`. To be used in a Seq2Seq model, the model needs to be initialized with both the :obj:`is_decoder` | |||
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an | |||
input to the forward pass. | |||
""" | |||
def __init__(self, model_dir=None, add_pooling_layer=True, **config): | |||
""" | |||
Args: | |||
model_dir (str, optional): The model checkpoint directory. Defaults to None. | |||
add_pooling_layer (bool, optional): whether to pool the output of the last hidden layer. Defaults to True. | |||
""" | |||
config = SbertConfig(**config) | |||
super().__init__(model_dir) | |||
self.config = config | |||
self.embeddings = SbertEmbeddings(config) | |||
self.encoder = SbertEncoder(config) | |||
self.pooler = SbertPooler(config) if add_pooling_layer else None | |||
self.init_weights() | |||
def get_input_embeddings(self): | |||
return self.embeddings.word_embeddings | |||
def set_input_embeddings(self, value): | |||
self.embeddings.word_embeddings = value | |||
def _prune_heads(self, heads_to_prune): | |||
""" | |||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base | |||
class PreTrainedModel | |||
""" | |||
for layer, heads in heads_to_prune.items(): | |||
self.encoder.layer[layer].attention.prune_heads(heads) | |||
def forward(self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
past_key_values=None, | |||
use_cache=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
**kwargs): | |||
r""" | |||
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)` | |||
, `optional`): | |||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if | |||
the model is configured as a decoder. | |||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in | |||
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` | |||
with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, | |||
sequence_length - 1, embed_size_per_head)`): | |||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. | |||
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` | |||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` | |||
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. | |||
use_cache (:obj:`bool`, `optional`): | |||
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up | |||
decoding (see :obj:`past_key_values`). | |||
""" | |||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |||
output_hidden_states = ( | |||
output_hidden_states if output_hidden_states is not None else | |||
self.config.output_hidden_states) | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
if self.config.is_decoder: | |||
use_cache = use_cache if use_cache is not None else self.config.use_cache | |||
else: | |||
use_cache = False | |||
if input_ids is not None and inputs_embeds is not None: | |||
raise ValueError( | |||
'You cannot specify both input_ids and inputs_embeds at the same time' | |||
) | |||
elif input_ids is not None: | |||
input_shape = input_ids.size() | |||
elif inputs_embeds is not None: | |||
input_shape = inputs_embeds.size()[:-1] | |||
else: | |||
raise ValueError( | |||
'You have to specify either input_ids or inputs_embeds') | |||
batch_size, seq_length = input_shape | |||
device = input_ids.device if input_ids is not None else inputs_embeds.device | |||
# past_key_values_length | |||
past_key_values_length = past_key_values[0][0].shape[ | |||
2] if past_key_values is not None else 0 | |||
if attention_mask is None: | |||
attention_mask = torch.ones( | |||
((batch_size, seq_length + past_key_values_length)), | |||
device=device) | |||
if token_type_ids is None: | |||
if hasattr(self.embeddings, 'token_type_ids'): | |||
buffered_token_type_ids = self.embeddings.token_type_ids[:, : | |||
seq_length] | |||
buffered_token_type_ids_expanded = buffered_token_type_ids.expand( | |||
batch_size, seq_length) | |||
token_type_ids = buffered_token_type_ids_expanded | |||
else: | |||
token_type_ids = torch.zeros( | |||
input_shape, dtype=torch.long, device=device) | |||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] | |||
# ourselves in which case we just need to make it broadcastable to all heads. | |||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( | |||
attention_mask, input_shape, device) | |||
# If a 2D or 3D attention mask is provided for the cross-attention | |||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] | |||
if self.config.is_decoder and encoder_hidden_states is not None: | |||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size( | |||
) | |||
encoder_hidden_shape = (encoder_batch_size, | |||
encoder_sequence_length) | |||
if encoder_attention_mask is None: | |||
encoder_attention_mask = torch.ones( | |||
encoder_hidden_shape, device=device) | |||
encoder_extended_attention_mask = self.invert_attention_mask( | |||
encoder_attention_mask) | |||
else: | |||
encoder_extended_attention_mask = None | |||
# Prepare head mask if needed | |||
# 1.0 in head_mask indicate we keep the head | |||
# attention_probs has shape bsz x n_heads x N x N | |||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] | |||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] | |||
head_mask = self.get_head_mask(head_mask, | |||
self.config.num_hidden_layers) | |||
embedding_output, orignal_embeds = self.embeddings( | |||
input_ids=input_ids, | |||
position_ids=position_ids, | |||
token_type_ids=token_type_ids, | |||
inputs_embeds=inputs_embeds, | |||
past_key_values_length=past_key_values_length, | |||
return_inputs_embeds=True, | |||
) | |||
encoder_outputs = self.encoder( | |||
embedding_output, | |||
attention_mask=extended_attention_mask, | |||
head_mask=head_mask, | |||
encoder_hidden_states=encoder_hidden_states, | |||
encoder_attention_mask=encoder_extended_attention_mask, | |||
past_key_values=past_key_values, | |||
use_cache=use_cache, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
sequence_output = encoder_outputs[0] | |||
pooled_output = self.pooler( | |||
sequence_output) if self.pooler is not None else None | |||
if not return_dict: | |||
return (sequence_output, | |||
pooled_output) + encoder_outputs[1:] + (orignal_embeds, ) | |||
return BaseModelOutputWithPoolingAndCrossAttentionsWithEmbedding( | |||
last_hidden_state=sequence_output, | |||
pooler_output=pooled_output, | |||
past_key_values=encoder_outputs.past_key_values, | |||
hidden_states=encoder_outputs.hidden_states, | |||
attentions=encoder_outputs.attentions, | |||
cross_attentions=encoder_outputs.cross_attentions, | |||
embedding_output=orignal_embeds) | |||
def extract_sequence_outputs(self, outputs): | |||
return outputs['last_hidden_state'] | |||
def extract_pooled_outputs(self, outputs): | |||
return outputs['pooler_output'] | |||
class SbertEmbeddings(nn.Module): | |||
"""Construct the embeddings from word, position and token_type embeddings.""" | |||
def __init__(self, config): | |||
super().__init__() | |||
self.word_embeddings = nn.Embedding( | |||
config.vocab_size, | |||
config.hidden_size, | |||
padding_idx=config.pad_token_id) | |||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, | |||
config.hidden_size) | |||
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, | |||
config.hidden_size) | |||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load | |||
# any TensorFlow checkpoint file | |||
self.LayerNorm = nn.LayerNorm( | |||
config.hidden_size, eps=config.layer_norm_eps) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized | |||
self.position_embedding_type = getattr(config, | |||
'position_embedding_type', | |||
'absolute') | |||
self.register_buffer( | |||
'position_ids', | |||
torch.arange(config.max_position_embeddings).expand((1, -1))) | |||
if version.parse(torch.__version__) > version.parse('1.6.0'): | |||
self.register_buffer( | |||
'token_type_ids', | |||
torch.zeros( | |||
self.position_ids.size(), | |||
dtype=torch.long, | |||
device=self.position_ids.device), | |||
persistent=False, | |||
) | |||
def forward(self, | |||
input_ids=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
inputs_embeds=None, | |||
past_key_values_length=0, | |||
return_inputs_embeds=False): | |||
if input_ids is not None: | |||
input_shape = input_ids.size() | |||
else: | |||
input_shape = inputs_embeds.size()[:-1] | |||
seq_length = input_shape[1] | |||
if position_ids is None: | |||
position_ids = self.position_ids[:, | |||
past_key_values_length:seq_length | |||
+ past_key_values_length] | |||
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs | |||
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids | |||
# issue #5664 | |||
if token_type_ids is None: | |||
if hasattr(self, 'token_type_ids'): | |||
buffered_token_type_ids = self.token_type_ids[:, :seq_length] | |||
buffered_token_type_ids_expanded = buffered_token_type_ids.expand( | |||
input_shape[0], seq_length) | |||
token_type_ids = buffered_token_type_ids_expanded | |||
else: | |||
token_type_ids = torch.zeros( | |||
input_shape, | |||
dtype=torch.long, | |||
device=self.position_ids.device) | |||
if inputs_embeds is None: | |||
inputs_embeds = self.word_embeddings(input_ids) | |||
token_type_embeddings = self.token_type_embeddings(token_type_ids) | |||
embeddings = inputs_embeds + token_type_embeddings | |||
if self.position_embedding_type == 'absolute': | |||
position_embeddings = self.position_embeddings(position_ids) | |||
embeddings += position_embeddings | |||
embeddings = self.LayerNorm(embeddings) | |||
embeddings = self.dropout(embeddings) | |||
if not return_inputs_embeds: | |||
return embeddings | |||
else: | |||
return embeddings, inputs_embeds | |||
class SbertSelfAttention(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr( | |||
config, 'embedding_size'): | |||
raise ValueError( | |||
f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention ' | |||
f'heads ({config.num_attention_heads})') | |||
self.num_attention_heads = config.num_attention_heads | |||
self.attention_head_size = int(config.hidden_size | |||
/ config.num_attention_heads) | |||
self.all_head_size = self.num_attention_heads * self.attention_head_size | |||
self.query = nn.Linear(config.hidden_size, self.all_head_size) | |||
self.key = nn.Linear(config.hidden_size, self.all_head_size) | |||
self.value = nn.Linear(config.hidden_size, self.all_head_size) | |||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) | |||
self.position_embedding_type = getattr(config, | |||
'position_embedding_type', | |||
'absolute') | |||
if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': | |||
self.max_position_embeddings = config.max_position_embeddings | |||
self.distance_embedding = nn.Embedding( | |||
2 * config.max_position_embeddings - 1, | |||
self.attention_head_size) | |||
self.is_decoder = config.is_decoder | |||
def transpose_for_scores(self, x): | |||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, | |||
self.attention_head_size) | |||
x = x.view(*new_x_shape) | |||
return x.permute(0, 2, 1, 3) | |||
def forward( | |||
self, | |||
hidden_states, | |||
attention_mask=None, | |||
head_mask=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
past_key_value=None, | |||
output_attentions=False, | |||
): | |||
mixed_query_layer = self.query(hidden_states) | |||
# If this is instantiated as a cross-attention module, the keys | |||
# and values come from an encoder; the attention mask needs to be | |||
# such that the encoder's padding tokens are not attended to. | |||
is_cross_attention = encoder_hidden_states is not None | |||
if is_cross_attention and past_key_value is not None: | |||
# reuse k,v, cross_attentions | |||
key_layer = past_key_value[0] | |||
value_layer = past_key_value[1] | |||
attention_mask = encoder_attention_mask | |||
elif is_cross_attention: | |||
key_layer = self.transpose_for_scores( | |||
self.key(encoder_hidden_states)) | |||
value_layer = self.transpose_for_scores( | |||
self.value(encoder_hidden_states)) | |||
attention_mask = encoder_attention_mask | |||
elif past_key_value is not None: | |||
key_layer = self.transpose_for_scores(self.key(hidden_states)) | |||
value_layer = self.transpose_for_scores(self.value(hidden_states)) | |||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2) | |||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2) | |||
else: | |||
key_layer = self.transpose_for_scores(self.key(hidden_states)) | |||
value_layer = self.transpose_for_scores(self.value(hidden_states)) | |||
query_layer = self.transpose_for_scores(mixed_query_layer) | |||
if self.is_decoder: | |||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. | |||
# Further calls to cross_attention layer can then reuse all cross-attention | |||
# key/value_states (first "if" case) | |||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of | |||
# all previous decoder key/value_states. Further calls to uni-directional self-attention | |||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) | |||
# if encoder bi-directional self-attention `past_key_value` is always `None` | |||
past_key_value = (key_layer, value_layer) | |||
# Take the dot product between "query" and "key" to get the raw attention scores. | |||
attention_scores = torch.matmul(query_layer, | |||
key_layer.transpose(-1, -2)) | |||
if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': | |||
seq_length = hidden_states.size()[1] | |||
position_ids_l = torch.arange( | |||
seq_length, dtype=torch.long, | |||
device=hidden_states.device).view(-1, 1) | |||
position_ids_r = torch.arange( | |||
seq_length, dtype=torch.long, | |||
device=hidden_states.device).view(1, -1) | |||
distance = position_ids_l - position_ids_r | |||
positional_embedding = self.distance_embedding( | |||
distance + self.max_position_embeddings - 1) | |||
positional_embedding = positional_embedding.to( | |||
dtype=query_layer.dtype) # fp16 compatibility | |||
if self.position_embedding_type == 'relative_key': | |||
relative_position_scores = torch.einsum( | |||
'bhld,lrd->bhlr', query_layer, positional_embedding) | |||
attention_scores = attention_scores + relative_position_scores | |||
elif self.position_embedding_type == 'relative_key_query': | |||
relative_position_scores_query = torch.einsum( | |||
'bhld,lrd->bhlr', query_layer, positional_embedding) | |||
relative_position_scores_key = torch.einsum( | |||
'bhrd,lrd->bhlr', key_layer, positional_embedding) | |||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key | |||
attention_scores = attention_scores / math.sqrt( | |||
self.attention_head_size) | |||
if attention_mask is not None: | |||
# Apply the attention mask is (precomputed for all layers in SbertModel forward() function) | |||
attention_scores = attention_scores + attention_mask | |||
# Normalize the attention scores to probabilities. | |||
attention_probs = nn.Softmax(dim=-1)(attention_scores) | |||
# This is actually dropping out entire tokens to attend to, which might | |||
# seem a bit unusual, but is taken from the original Transformer paper. | |||
attention_probs = self.dropout(attention_probs) | |||
# Mask heads if we want to | |||
if head_mask is not None: | |||
attention_probs = attention_probs * head_mask | |||
context_layer = torch.matmul(attention_probs, value_layer) | |||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |||
new_context_layer_shape = context_layer.size()[:-2] + ( | |||
self.all_head_size, ) | |||
context_layer = context_layer.view(*new_context_layer_shape) | |||
outputs = (context_layer, | |||
attention_probs) if output_attentions else (context_layer, ) | |||
if self.is_decoder: | |||
outputs = outputs + (past_key_value, ) | |||
return outputs | |||
class SbertSelfOutput(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||
self.LayerNorm = nn.LayerNorm( | |||
config.hidden_size, eps=config.layer_norm_eps) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
def forward(self, hidden_states, input_tensor): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.dropout(hidden_states) | |||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||
return hidden_states | |||
class SbertAttention(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.self = SbertSelfAttention(config) | |||
self.output = SbertSelfOutput(config) | |||
self.pruned_heads = set() | |||
def prune_heads(self, heads): | |||
if len(heads) == 0: | |||
return | |||
heads, index = find_pruneable_heads_and_indices( | |||
heads, self.self.num_attention_heads, | |||
self.self.attention_head_size, self.pruned_heads) | |||
# Prune linear layers | |||
self.self.query = prune_linear_layer(self.self.query, index) | |||
self.self.key = prune_linear_layer(self.self.key, index) | |||
self.self.value = prune_linear_layer(self.self.value, index) | |||
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) | |||
# Update hyper params and store pruned heads | |||
self.self.num_attention_heads = self.self.num_attention_heads - len( | |||
heads) | |||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads | |||
self.pruned_heads = self.pruned_heads.union(heads) | |||
def forward( | |||
self, | |||
hidden_states, | |||
attention_mask=None, | |||
head_mask=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
past_key_value=None, | |||
output_attentions=False, | |||
): | |||
self_outputs = self.self( | |||
hidden_states, | |||
attention_mask, | |||
head_mask, | |||
encoder_hidden_states, | |||
encoder_attention_mask, | |||
past_key_value, | |||
output_attentions, | |||
) | |||
attention_output = self.output(self_outputs[0], hidden_states) | |||
outputs = (attention_output, | |||
) + self_outputs[1:] # add attentions if we output them | |||
return outputs | |||
class SbertIntermediate(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size) | |||
if isinstance(config.hidden_act, str): | |||
self.intermediate_act_fn = ACT2FN[config.hidden_act] | |||
else: | |||
self.intermediate_act_fn = config.hidden_act | |||
def forward(self, hidden_states): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.intermediate_act_fn(hidden_states) | |||
return hidden_states | |||
class SbertOutput(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size) | |||
self.LayerNorm = nn.LayerNorm( | |||
config.hidden_size, eps=config.layer_norm_eps) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
def forward(self, hidden_states, input_tensor): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.dropout(hidden_states) | |||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||
return hidden_states | |||
class SbertLayer(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.chunk_size_feed_forward = config.chunk_size_feed_forward | |||
self.seq_len_dim = 1 | |||
self.attention = SbertAttention(config) | |||
self.is_decoder = config.is_decoder | |||
self.add_cross_attention = config.add_cross_attention | |||
if self.add_cross_attention: | |||
if not self.is_decoder: | |||
raise ValueError( | |||
f'{self} should be used as a decoder model if cross attention is added' | |||
) | |||
self.crossattention = SbertAttention(config) | |||
self.intermediate = SbertIntermediate(config) | |||
self.output = SbertOutput(config) | |||
def forward( | |||
self, | |||
hidden_states, | |||
attention_mask=None, | |||
head_mask=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
past_key_value=None, | |||
output_attentions=False, | |||
): | |||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 | |||
self_attn_past_key_value = past_key_value[: | |||
2] if past_key_value is not None else None | |||
self_attention_outputs = self.attention( | |||
hidden_states, | |||
attention_mask, | |||
head_mask, | |||
output_attentions=output_attentions, | |||
past_key_value=self_attn_past_key_value, | |||
) | |||
attention_output = self_attention_outputs[0] | |||
# if decoder, the last output is tuple of self-attn cache | |||
if self.is_decoder: | |||
outputs = self_attention_outputs[1:-1] | |||
present_key_value = self_attention_outputs[-1] | |||
else: | |||
outputs = self_attention_outputs[ | |||
1:] # add self attentions if we output attention weights | |||
cross_attn_present_key_value = None | |||
if self.is_decoder and encoder_hidden_states is not None: | |||
if not hasattr(self, 'crossattention'): | |||
raise ValueError( | |||
f'If `encoder_hidden_states` are passed, {self} has to be instantiated ' | |||
f'with cross-attention layers by setting `config.add_cross_attention=True`' | |||
) | |||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple | |||
cross_attn_past_key_value = past_key_value[ | |||
-2:] if past_key_value is not None else None | |||
cross_attention_outputs = self.crossattention( | |||
attention_output, | |||
attention_mask, | |||
head_mask, | |||
encoder_hidden_states, | |||
encoder_attention_mask, | |||
cross_attn_past_key_value, | |||
output_attentions, | |||
) | |||
attention_output = cross_attention_outputs[0] | |||
outputs = outputs + cross_attention_outputs[ | |||
1:-1] # add cross attentions if we output attention weights | |||
# add cross-attn cache to positions 3,4 of present_key_value tuple | |||
cross_attn_present_key_value = cross_attention_outputs[-1] | |||
present_key_value = present_key_value + cross_attn_present_key_value | |||
layer_output = apply_chunking_to_forward(self.feed_forward_chunk, | |||
self.chunk_size_feed_forward, | |||
self.seq_len_dim, | |||
attention_output) | |||
outputs = (layer_output, ) + outputs | |||
# if decoder, return the attn key/values as the last output | |||
if self.is_decoder: | |||
outputs = outputs + (present_key_value, ) | |||
return outputs | |||
def feed_forward_chunk(self, attention_output): | |||
intermediate_output = self.intermediate(attention_output) | |||
layer_output = self.output(intermediate_output, attention_output) | |||
return layer_output | |||
class SbertEncoder(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.config = config | |||
self.layer = nn.ModuleList( | |||
[SbertLayer(config) for _ in range(config.num_hidden_layers)]) | |||
self.gradient_checkpointing = False | |||
def forward( | |||
self, | |||
hidden_states, | |||
attention_mask=None, | |||
head_mask=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
past_key_values=None, | |||
use_cache=None, | |||
output_attentions=False, | |||
output_hidden_states=False, | |||
return_dict=True, | |||
): | |||
all_hidden_states = () if output_hidden_states else None | |||
all_self_attentions = () if output_attentions else None | |||
all_cross_attentions = ( | |||
) if output_attentions and self.config.add_cross_attention else None | |||
next_decoder_cache = () if use_cache else None | |||
for i, layer_module in enumerate(self.layer): | |||
if output_hidden_states: | |||
all_hidden_states = all_hidden_states + (hidden_states, ) | |||
layer_head_mask = head_mask[i] if head_mask is not None else None | |||
past_key_value = past_key_values[ | |||
i] if past_key_values is not None else None | |||
if self.gradient_checkpointing and self.training: | |||
if use_cache: | |||
logger.warning( | |||
'`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' | |||
) | |||
use_cache = False | |||
def create_custom_forward(module): | |||
def custom_forward(*inputs): | |||
return module(*inputs, past_key_value, | |||
output_attentions) | |||
return custom_forward | |||
layer_outputs = torch.utils.checkpoint.checkpoint( | |||
create_custom_forward(layer_module), | |||
hidden_states, | |||
attention_mask, | |||
layer_head_mask, | |||
encoder_hidden_states, | |||
encoder_attention_mask, | |||
) | |||
else: | |||
layer_outputs = layer_module( | |||
hidden_states, | |||
attention_mask, | |||
layer_head_mask, | |||
encoder_hidden_states, | |||
encoder_attention_mask, | |||
past_key_value, | |||
output_attentions, | |||
) | |||
hidden_states = layer_outputs[0] | |||
if use_cache: | |||
next_decoder_cache += (layer_outputs[-1], ) | |||
if output_attentions: | |||
all_self_attentions = all_self_attentions + ( | |||
layer_outputs[1], ) | |||
if self.config.add_cross_attention: | |||
all_cross_attentions = all_cross_attentions + ( | |||
layer_outputs[2], ) | |||
if output_hidden_states: | |||
all_hidden_states = all_hidden_states + (hidden_states, ) | |||
if not return_dict: | |||
return tuple(v for v in [ | |||
hidden_states, | |||
next_decoder_cache, | |||
all_hidden_states, | |||
all_self_attentions, | |||
all_cross_attentions, | |||
] if v is not None) | |||
return BaseModelOutputWithPastAndCrossAttentions( | |||
last_hidden_state=hidden_states, | |||
past_key_values=next_decoder_cache, | |||
hidden_states=all_hidden_states, | |||
attentions=all_self_attentions, | |||
cross_attentions=all_cross_attentions, | |||
) | |||
class SbertPooler(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||
self.activation = nn.Tanh() | |||
def forward(self, hidden_states): | |||
# We "pool" the model by simply taking the hidden state corresponding | |||
# to the first token. | |||
first_token_tensor = hidden_states[:, 0] | |||
pooled_output = self.dense(first_token_tensor) | |||
pooled_output = self.activation(pooled_output) | |||
return pooled_output | |||
@dataclass | |||
class SbertForPreTrainingOutput(ModelOutput): | |||
""" | |||
Output type of :class:`~structbert.utils.BertForPreTraining`. | |||
Args: | |||
loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): | |||
Total loss as the sum of the masked language modeling loss and the next sequence prediction | |||
(classification) loss. | |||
prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): | |||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). | |||
seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): | |||
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation | |||
before SoftMax). | |||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when | |||
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): | |||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) | |||
of shape :obj:`(batch_size, sequence_length, hidden_size)`. | |||
Hidden-states of the model at the output of each layer plus the initial embedding outputs. | |||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when | |||
``output_attentions=True`` is passed or when ``config.output_attentions=True``): | |||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, | |||
sequence_length, sequence_length)`. | |||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention | |||
heads. | |||
""" | |||
loss: Optional[torch.FloatTensor] = None | |||
prediction_logits: torch.FloatTensor = None | |||
seq_relationship_logits: torch.FloatTensor = None | |||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None | |||
attentions: Optional[Tuple[torch.FloatTensor]] = None | |||
@dataclass | |||
class BaseModelOutputWithPoolingAndCrossAttentionsWithEmbedding( | |||
BaseModelOutputWithPoolingAndCrossAttentions): | |||
embedding_output: torch.FloatTensor = None | |||
logits: Optional[Union[tuple, torch.FloatTensor]] = None | |||
kwargs: dict = None |
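A quick sketch of what the pooler defined above does: it takes the hidden state of the first token and passes it through a dense layer with a tanh activation. The config stand-in and tensor shapes below are illustrative assumptions, not part of this patch; it only assumes the SbertPooler class above is in scope.

import torch
from types import SimpleNamespace

# SimpleNamespace stands in for the real config; only hidden_size is needed here.
config = SimpleNamespace(hidden_size=768)
pooler = SbertPooler(config)
hidden_states = torch.randn(2, 16, 768)   # (batch, seq_len, hidden)
pooled = pooler(hidden_states)            # uses hidden_states[:, 0] only
print(pooled.shape)                       # torch.Size([2, 768])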
@@ -0,0 +1,3 @@ | |||
from .sequence_classification_head import SequenceClassificationHead | |||
__all__ = ['SequenceClassificationHead'] |
@@ -0,0 +1,44 @@ | |||
import importlib | |||
from typing import Dict, List, Optional, Union | |||
import torch | |||
import torch.nn.functional as F | |||
from torch import nn | |||
from ....metainfo import Heads | |||
from ....outputs import OutputKeys | |||
from ....utils.constant import Tasks | |||
from ...base import TorchHead | |||
from ...builder import HEADS | |||
@HEADS.register_module( | |||
Tasks.text_classification, module_name=Heads.text_classification) | |||
class SequenceClassificationHead(TorchHead): | |||
def __init__(self, **kwargs): | |||
super().__init__(**kwargs) | |||
config = self.config | |||
self.num_labels = config.num_labels | |||
self.config = config | |||
classifier_dropout = ( | |||
config['classifier_dropout'] if config.get('classifier_dropout') | |||
is not None else config['hidden_dropout_prob']) | |||
self.dropout = nn.Dropout(classifier_dropout) | |||
self.classifier = nn.Linear(config['hidden_size'], | |||
config['num_labels']) | |||
def forward(self, inputs=None): | |||
if isinstance(inputs, dict): | |||
assert inputs.get('pooled_output') is not None | |||
pooled_output = inputs.get('pooled_output') | |||
else: | |||
pooled_output = inputs | |||
pooled_output = self.dropout(pooled_output) | |||
logits = self.classifier(pooled_output) | |||
return {OutputKeys.LOGITS: logits} | |||
def compute_loss(self, outputs: Dict[str, torch.Tensor], | |||
labels) -> Dict[str, torch.Tensor]: | |||
logits = outputs[OutputKeys.LOGITS] | |||
return {OutputKeys.LOSS: F.cross_entropy(logits, labels)} |
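A minimal sketch of using the head directly, assuming the TorchHead base wraps the keyword arguments into an attribute/key-accessible self.config (a ConfigDict), as the __init__ above implies; in normal use the head is built from the "head" section of configuration.json via the HEADS registry instead.

import torch

# Hypothetical direct construction; the kwargs mirror the head config keys used above.
head = SequenceClassificationHead(
    hidden_size=768, hidden_dropout_prob=0.1, num_labels=2)
pooled_output = torch.randn(4, 768)               # e.g. output of the backbone pooler
outputs = head.forward({'pooled_output': pooled_output})   # {'logits': ...}
labels = torch.tensor([0, 1, 1, 0])
loss = head.compute_loss(outputs, labels)         # {'loss': tensor(...)}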
@@ -2,8 +2,7 @@ from typing import Dict | |||
from ...metainfo import Models | |||
from ...utils.constant import Tasks | |||
from ..base import Tensor | |||
from ..base_torch import TorchModel | |||
from ..base import Tensor, TorchModel | |||
from ..builder import MODELS | |||
__all__ = ['PalmForTextGeneration'] | |||
@@ -42,6 +42,9 @@ class SbertTextClassfier(SbertPreTrainedModel): | |||
return {'logits': logits, 'loss': loss} | |||
return {'logits': logits} | |||
def build(**kwargs): | |||

return SbertTextClassfier.from_pretrained(model_dir, **model_args) | |||
class SbertForSequenceClassificationBase(Model): | |||
@@ -0,0 +1,85 @@ | |||
import os | |||
from typing import Any, Dict | |||
import json | |||
import numpy as np | |||
from ...metainfo import TaskModels | |||
from ...outputs import OutputKeys | |||
from ...utils.constant import Tasks | |||
from ..builder import MODELS | |||
from .task_model import SingleBackboneTaskModelBase | |||
__all__ = ['SequenceClassificationModel'] | |||
@MODELS.register_module( | |||
Tasks.sentiment_classification, module_name=TaskModels.text_classification) | |||
@MODELS.register_module( | |||
Tasks.text_classification, module_name=TaskModels.text_classification) | |||
class SequenceClassificationModel(SingleBackboneTaskModelBase): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
"""initialize the sequence classification model from the `model_dir` path. | |||
Args: | |||
model_dir (str): the model path. | |||
""" | |||
super().__init__(model_dir, *args, **kwargs) | |||
if 'base_model_prefix' in kwargs: | |||
self._base_model_prefix = kwargs['base_model_prefix'] | |||
backbone_cfg = self.cfg.backbone | |||
head_cfg = self.cfg.head | |||
# get the num_labels from label_mapping.json | |||
self.id2label = {} | |||
self.label_path = os.path.join(model_dir, 'label_mapping.json') | |||
if os.path.exists(self.label_path): | |||
with open(self.label_path) as f: | |||
self.label_mapping = json.load(f) | |||
self.id2label = { | |||
idx: name | |||
for name, idx in self.label_mapping.items() | |||
} | |||
head_cfg['num_labels'] = len(self.label_mapping) | |||
self.build_backbone(backbone_cfg) | |||
self.build_head(head_cfg) | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | |||
outputs = super().forward(input) | |||
sequence_output, pooled_output = self.extract_backbone_outputs(outputs) | |||
outputs = self.head.forward(pooled_output) | |||
if 'labels' in input: | |||
loss = self.compute_loss(outputs, input['labels']) | |||
outputs.update(loss) | |||
return outputs | |||
def extract_logits(self, outputs): | |||
return outputs[OutputKeys.LOGITS].cpu().detach() | |||
def extract_backbone_outputs(self, outputs): | |||
sequence_output = None | |||
pooled_output = None | |||
if hasattr(self.backbone, 'extract_sequence_outputs'): | |||
sequence_output = self.backbone.extract_sequence_outputs(outputs) | |||
if hasattr(self.backbone, 'extract_pooled_outputs'): | |||
pooled_output = self.backbone.extract_pooled_outputs(outputs) | |||
return sequence_output, pooled_output | |||
def compute_loss(self, outputs, labels): | |||
loss = self.head.compute_loss(outputs, labels) | |||
return loss | |||
def postprocess(self, input, **kwargs): | |||
logits = self.extract_logits(input) | |||
probs = logits.softmax(-1).numpy() | |||
pred = logits.argmax(-1).numpy() | |||
logits = logits.numpy() | |||
res = { | |||
OutputKeys.PREDICTIONS: pred, | |||
OutputKeys.PROBABILITIES: probs, | |||
OutputKeys.LOGITS: logits | |||
} | |||
return res |
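In practice this task model is reached through the registry rather than constructed by hand; the sentiment classification tests later in this patch exercise it via the pipeline API. A rough sketch of that call path, with the model id and the 'beta' revision taken from those tests and the input string being an arbitrary example:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(
    Tasks.sentiment_classification,
    model='damo/nlp_structbert_sentiment-classification_chinese-base',
    model_revision='beta')
# postprocess() above turns the logits into predictions / probabilities / logits
print(pipe(input='an arbitrary example sentence'))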
@@ -3,15 +3,14 @@ | |||
import os | |||
from typing import Any, Dict | |||
from ....metainfo import Models | |||
from ....preprocessors.space.fields.intent_field import IntentBPETextField | |||
from ....trainers.nlp.space.trainer.intent_trainer import IntentTrainer | |||
from ....utils.config import Config | |||
from ....utils.constant import ModelFile, Tasks | |||
from ...base import Model, Tensor | |||
from ...builder import MODELS | |||
from .model.generator import Generator | |||
from .model.model_base import SpaceModelBase | |||
from ...metainfo import Models | |||
from ...preprocessors.space.fields.intent_field import IntentBPETextField | |||
from ...trainers.nlp.space.trainer.intent_trainer import IntentTrainer | |||
from ...utils.config import Config | |||
from ...utils.constant import ModelFile, Tasks | |||
from ..base import Model, Tensor | |||
from ..builder import MODELS | |||
from .backbones import SpaceGenerator, SpaceModelBase | |||
__all__ = ['SpaceForDialogIntent'] | |||
@@ -37,7 +36,8 @@ class SpaceForDialogIntent(Model): | |||
'text_field', | |||
IntentBPETextField(self.model_dir, config=self.config)) | |||
self.generator = Generator.create(self.config, reader=self.text_field) | |||
self.generator = SpaceGenerator.create( | |||
self.config, reader=self.text_field) | |||
self.model = SpaceModelBase.create( | |||
model_dir=model_dir, | |||
config=self.config, |
@@ -3,15 +3,14 @@ | |||
import os | |||
from typing import Any, Dict, Optional | |||
from ....metainfo import Models | |||
from ....preprocessors.space.fields.gen_field import MultiWOZBPETextField | |||
from ....trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer | |||
from ....utils.config import Config | |||
from ....utils.constant import ModelFile, Tasks | |||
from ...base import Model, Tensor | |||
from ...builder import MODELS | |||
from .model.generator import Generator | |||
from .model.model_base import SpaceModelBase | |||
from ...metainfo import Models | |||
from ...preprocessors.space.fields.gen_field import MultiWOZBPETextField | |||
from ...trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer | |||
from ...utils.config import Config | |||
from ...utils.constant import ModelFile, Tasks | |||
from ..base import Model, Tensor | |||
from ..builder import MODELS | |||
from .backbones import SpaceGenerator, SpaceModelBase | |||
__all__ = ['SpaceForDialogModeling'] | |||
@@ -35,7 +34,8 @@ class SpaceForDialogModeling(Model): | |||
self.text_field = kwargs.pop( | |||
'text_field', | |||
MultiWOZBPETextField(self.model_dir, config=self.config)) | |||
self.generator = Generator.create(self.config, reader=self.text_field) | |||
self.generator = SpaceGenerator.create( | |||
self.config, reader=self.text_field) | |||
self.model = SpaceModelBase.create( | |||
model_dir=model_dir, | |||
config=self.config, |
@@ -2,10 +2,10 @@ import os | |||
from typing import Any, Dict | |||
from modelscope.utils.constant import Tasks | |||
from ....metainfo import Models | |||
from ....utils.nlp.space.utils_dst import batch_to_device | |||
from ...base import Model, Tensor | |||
from ...builder import MODELS | |||
from ...metainfo import Models | |||
from ...utils.nlp.space.utils_dst import batch_to_device | |||
from ..base import Model, Tensor | |||
from ..builder import MODELS | |||
__all__ = ['SpaceForDialogStateTracking'] | |||
@@ -0,0 +1,489 @@ | |||
import os.path | |||
import re | |||
from abc import ABC | |||
from collections import OrderedDict | |||
from typing import Any, Dict | |||
import torch | |||
from torch import nn | |||
from ...utils.config import ConfigDict | |||
from ...utils.constant import Fields, Tasks | |||
from ...utils.logger import get_logger | |||
from ...utils.utils import if_func_recieve_dict_inputs | |||
from ..base import TorchModel | |||
from ..builder import build_backbone, build_head | |||
logger = get_logger(__name__) | |||
__all__ = ['EncoderDecoderTaskModelBase', 'SingleBackboneTaskModelBase'] | |||
def _repr(modules, depth=1): | |||
# control how deep the module repr recurses; depth == 0 prints only the module name | |||
if depth == 0: | |||
return modules._get_name() | |||
# We treat the extra repr like the sub-module, one item per line | |||
extra_lines = [] | |||
extra_repr = modules.extra_repr() | |||
# empty string will be split into list [''] | |||
if extra_repr: | |||
extra_lines = extra_repr.split('\n') | |||
child_lines = [] | |||
def _addindent(s_, numSpaces): | |||
s = s_.split('\n') | |||
# don't do anything for single-line stuff | |||
if len(s) == 1: | |||
return s_ | |||
first = s.pop(0) | |||
s = [(numSpaces * ' ') + line for line in s] | |||
s = '\n'.join(s) | |||
s = first + '\n' + s | |||
return s | |||
for key, module in modules._modules.items(): | |||
mod_str = _repr(module, depth - 1) | |||
mod_str = _addindent(mod_str, 2) | |||
child_lines.append('(' + key + '): ' + mod_str) | |||
lines = extra_lines + child_lines | |||
main_str = modules._get_name() + '(' | |||
if lines: | |||
# simple one-liner info, which most builtin Modules will use | |||
if len(extra_lines) == 1 and not child_lines: | |||
main_str += extra_lines[0] | |||
else: | |||
main_str += '\n ' + '\n '.join(lines) + '\n' | |||
main_str += ')' | |||
return main_str | |||
class BaseTaskModel(TorchModel, ABC): | |||
""" Base task model interface for nlp | |||
""" | |||
# state_dict keys to ignore when reported missing on load | |||
_keys_to_ignore_on_load_missing = None | |||
# state_dict keys to ignore when reported unexpected on load | |||
_keys_to_ignore_on_load_unexpected = None | |||
# backbone prefix, default None | |||
_backbone_prefix = None | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
super().__init__(model_dir, *args, **kwargs) | |||
self.cfg = ConfigDict(kwargs) | |||
def __repr__(self): | |||
# only log backbone and head name | |||
depth = 1 | |||
return _repr(self, depth) | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
model_dir = kwargs.get('model_dir') | |||
model = cls(**kwargs) | |||
model.load_checkpoint(model_local_dir=model_dir, **kwargs) | |||
return model | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
pass | |||
def load_checkpoint(self, | |||
model_local_dir, | |||
default_dtype=None, | |||
load_state_fn=None, | |||
**kwargs): | |||
""" | |||
Load the checkpoint file and feed its parameters into the model. | |||
Args: | |||
model_local_dir: The checkpoint directory on the local disk. | |||
default_dtype: Optionally set the default float dtype via `torch.set_default_dtype`. | |||
load_state_fn: An optional function used to load the state_dict into the model. | |||
Returns: | |||
A dict containing the `missing_keys`, `unexpected_keys`, `mismatched_keys` and `error_msgs` collected while loading. | |||
""" | |||
# TODO Sharded ckpt | |||
ckpt_file = os.path.join(model_local_dir, 'pytorch_model.bin') | |||
state_dict = torch.load(ckpt_file, map_location='cpu') | |||
if default_dtype is not None: | |||
torch.set_default_dtype(default_dtype) | |||
missing_keys, unexpected_keys, mismatched_keys, error_msgs = self._load_checkpoint( | |||
state_dict, | |||
load_state_fn=load_state_fn, | |||
ignore_mismatched_sizes=True, | |||
_fast_init=True, | |||
) | |||
return { | |||
'missing_keys': missing_keys, | |||
'unexpected_keys': unexpected_keys, | |||
'mismatched_keys': mismatched_keys, | |||
'error_msgs': error_msgs, | |||
} | |||
def _load_checkpoint( | |||
self, | |||
state_dict, | |||
load_state_fn, | |||
ignore_mismatched_sizes, | |||
_fast_init, | |||
): | |||
# Retrieve missing & unexpected_keys | |||
model_state_dict = self.state_dict() | |||
prefix = self._backbone_prefix | |||
# add head prefix | |||
new_state_dict = OrderedDict() | |||
for name, module in state_dict.items(): | |||
if not name.startswith(prefix) and not name.startswith('head'): | |||
new_state_dict['.'.join(['head', name])] = module | |||
else: | |||
new_state_dict[name] = module | |||
state_dict = new_state_dict | |||
loaded_keys = [k for k in state_dict.keys()] | |||
expected_keys = list(model_state_dict.keys()) | |||
def _fix_key(key): | |||
if 'beta' in key: | |||
return key.replace('beta', 'bias') | |||
if 'gamma' in key: | |||
return key.replace('gamma', 'weight') | |||
return key | |||
original_loaded_keys = loaded_keys | |||
loaded_keys = [_fix_key(key) for key in loaded_keys] | |||
if len(prefix) > 0: | |||
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) | |||
expects_prefix_module = any( | |||
s.startswith(prefix) for s in expected_keys) | |||
else: | |||
has_prefix_module = False | |||
expects_prefix_module = False | |||
# key re-naming operations are never done on the keys | |||
# that are loaded, but always on the keys of the newly initialized model | |||
remove_prefix_from_model = not has_prefix_module and expects_prefix_module | |||
add_prefix_to_model = has_prefix_module and not expects_prefix_module | |||
if remove_prefix_from_model: | |||
expected_keys_not_prefixed = [ | |||
s for s in expected_keys if not s.startswith(prefix) | |||
] | |||
expected_keys = [ | |||
'.'.join(s.split('.')[1:]) if s.startswith(prefix) else s | |||
for s in expected_keys | |||
] | |||
elif add_prefix_to_model: | |||
expected_keys = ['.'.join([prefix, s]) for s in expected_keys] | |||
missing_keys = list(set(expected_keys) - set(loaded_keys)) | |||
unexpected_keys = list(set(loaded_keys) - set(expected_keys)) | |||
if self._keys_to_ignore_on_load_missing is not None: | |||
for pat in self._keys_to_ignore_on_load_missing: | |||
missing_keys = [ | |||
k for k in missing_keys if re.search(pat, k) is None | |||
] | |||
if self._keys_to_ignore_on_load_unexpected is not None: | |||
for pat in self._keys_to_ignore_on_load_unexpected: | |||
unexpected_keys = [ | |||
k for k in unexpected_keys if re.search(pat, k) is None | |||
] | |||
if _fast_init: | |||
# retrieve uninitialized modules and initialize them | |||
uninitialized_modules = self.retrieve_modules_from_names( | |||
missing_keys, | |||
prefix=prefix, | |||
add_prefix=add_prefix_to_model, | |||
remove_prefix=remove_prefix_from_model) | |||
for module in uninitialized_modules: | |||
self._init_weights(module) | |||
# Make sure we are able to load base models as well as derived models (with heads) | |||
start_prefix = '' | |||
model_to_load = self | |||
if len(prefix) > 0 and not hasattr(self, prefix) and has_prefix_module: | |||
start_prefix = prefix + '.' | |||
if len(prefix) > 0 and hasattr(self, prefix) and not has_prefix_module: | |||
model_to_load = getattr(self, prefix) | |||
if any(key in expected_keys_not_prefixed for key in loaded_keys): | |||
raise ValueError( | |||
'The state dictionary of the model you are trying to load is corrupted. Are you sure it was ' | |||
'properly saved?') | |||
def _find_mismatched_keys( | |||
state_dict, | |||
model_state_dict, | |||
loaded_keys, | |||
add_prefix_to_model, | |||
remove_prefix_from_model, | |||
ignore_mismatched_sizes, | |||
): | |||
mismatched_keys = [] | |||
if ignore_mismatched_sizes: | |||
for checkpoint_key in loaded_keys: | |||
model_key = checkpoint_key | |||
if remove_prefix_from_model: | |||
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. | |||
model_key = f'{prefix}.{checkpoint_key}' | |||
elif add_prefix_to_model: | |||
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it. | |||
model_key = '.'.join(checkpoint_key.split('.')[1:]) | |||
if (model_key in model_state_dict): | |||
model_shape = model_state_dict[model_key].shape | |||
checkpoint_shape = state_dict[checkpoint_key].shape | |||
if (checkpoint_shape != model_shape): | |||
mismatched_keys.append( | |||
(checkpoint_key, | |||
state_dict[checkpoint_key].shape, | |||
model_state_dict[model_key].shape)) | |||
del state_dict[checkpoint_key] | |||
return mismatched_keys | |||
def _load_state_dict_into_model(model_to_load, state_dict, | |||
start_prefix): | |||
# Convert old format to new format if needed from a PyTorch state_dict | |||
old_keys = [] | |||
new_keys = [] | |||
for key in state_dict.keys(): | |||
new_key = None | |||
if 'gamma' in key: | |||
new_key = key.replace('gamma', 'weight') | |||
if 'beta' in key: | |||
new_key = key.replace('beta', 'bias') | |||
if new_key: | |||
old_keys.append(key) | |||
new_keys.append(new_key) | |||
for old_key, new_key in zip(old_keys, new_keys): | |||
state_dict[new_key] = state_dict.pop(old_key) | |||
# copy state_dict so _load_from_state_dict can modify it | |||
metadata = getattr(state_dict, '_metadata', None) | |||
state_dict = state_dict.copy() | |||
if metadata is not None: | |||
state_dict._metadata = metadata | |||
error_msgs = [] | |||
if load_state_fn is not None: | |||
load_state_fn( | |||
model_to_load, | |||
state_dict, | |||
prefix=start_prefix, | |||
local_metadata=None, | |||
error_msgs=error_msgs) | |||
else: | |||
def load(module: nn.Module, prefix=''): | |||
local_metadata = {} if metadata is None else metadata.get( | |||
prefix[:-1], {}) | |||
args = (state_dict, prefix, local_metadata, True, [], [], | |||
error_msgs) | |||
module._load_from_state_dict(*args) | |||
for name, child in module._modules.items(): | |||
if child is not None: | |||
load(child, prefix + name + '.') | |||
load(model_to_load, prefix=start_prefix) | |||
return error_msgs | |||
# Whole checkpoint | |||
mismatched_keys = _find_mismatched_keys( | |||
state_dict, | |||
model_state_dict, | |||
original_loaded_keys, | |||
add_prefix_to_model, | |||
remove_prefix_from_model, | |||
ignore_mismatched_sizes, | |||
) | |||
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, | |||
start_prefix) | |||
if len(error_msgs) > 0: | |||
error_msg = '\n\t'.join(error_msgs) | |||
raise RuntimeError( | |||
f'Error(s) in loading state_dict for {self.__class__.__name__}:\n\t{error_msg}' | |||
) | |||
if len(unexpected_keys) > 0: | |||
logger.warning( | |||
f'Some weights of the model checkpoint were not used when' | |||
f' initializing {self.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are' | |||
f' initializing {self.__class__.__name__} from the checkpoint of a model trained on another task or' | |||
' with another architecture (e.g. initializing a BertForSequenceClassification model from a' | |||
' BertForPreTraining model).\n- This IS NOT expected if you are initializing' | |||
f' {self.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical' | |||
' (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).' | |||
) | |||
else: | |||
logger.info( | |||
f'All model checkpoint weights were used when initializing {self.__class__.__name__}.\n' | |||
) | |||
if len(missing_keys) > 0: | |||
logger.warning( | |||
f'Some weights of {self.__class__.__name__} were not initialized from the model checkpoint' | |||
f' and are newly initialized: {missing_keys}\nYou should probably' | |||
' TRAIN this model on a down-stream task to be able to use it for predictions and inference.' | |||
) | |||
elif len(mismatched_keys) == 0: | |||
logger.info( | |||
f'All the weights of {self.__class__.__name__} were initialized from the model checkpoint.\n' | |||
f'If your task is similar to the task the model of the checkpoint' | |||
f' was trained on, you can already use {self.__class__.__name__} for predictions without further' | |||
' training.') | |||
if len(mismatched_keys) > 0: | |||
mismatched_warning = '\n'.join([ | |||
f'- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated' | |||
for key, shape1, shape2 in mismatched_keys | |||
]) | |||
logger.warning( | |||
f'Some weights of {self.__class__.__name__} were not initialized from the model checkpoint' | |||
f' and are newly initialized because the shapes did not' | |||
f' match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able' | |||
' to use it for predictions and inference.') | |||
return missing_keys, unexpected_keys, mismatched_keys, error_msgs | |||
def retrieve_modules_from_names(self, | |||
names, | |||
prefix=None, | |||
add_prefix=False, | |||
remove_prefix=False): | |||
module_keys = set(['.'.join(key.split('.')[:-1]) for key in names]) | |||
# torch.nn.ParameterList is a special case where two parameter keywords | |||
# are appended to the module name, *e.g.* bert.special_embeddings.0 | |||
module_keys = module_keys.union( | |||
set([ | |||
'.'.join(key.split('.')[:-2]) for key in names | |||
if key[-1].isdigit() | |||
])) | |||
retrieved_modules = [] | |||
# retrieve all modules that have at least one missing weight name | |||
for name, module in self.named_modules(): | |||
if remove_prefix: | |||
name = '.'.join( | |||
name.split('.')[1:]) if name.startswith(prefix) else name | |||
elif add_prefix: | |||
name = '.'.join([prefix, name]) if len(name) > 0 else prefix | |||
if name in module_keys: | |||
retrieved_modules.append(module) | |||
return retrieved_modules | |||
class SingleBackboneTaskModelBase(BaseTaskModel): | |||
""" | |||
This is the base class for all single-backbone NLP task models. | |||
""" | |||
# The backbone prefix defaults to "bert" | |||
_backbone_prefix = 'bert' | |||
# The head prefix defaults to "head" | |||
_head_prefix = 'head' | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
super().__init__(model_dir, *args, **kwargs) | |||
def build_backbone(self, cfg): | |||
if 'prefix' in cfg: | |||
self._backbone_prefix = cfg['prefix'] | |||
backbone = build_backbone(cfg, field=Fields.nlp) | |||
setattr(self, cfg['prefix'], backbone) | |||
def build_head(self, cfg): | |||
if 'prefix' in cfg: | |||
self._head_prefix = cfg['prefix'] | |||
head = build_head(cfg) | |||
setattr(self, self._head_prefix, head) | |||
return head | |||
@property | |||
def backbone(self): | |||
if 'backbone' != self._backbone_prefix: | |||
return getattr(self, self._backbone_prefix) | |||
return super().__getattr__('backbone') | |||
@property | |||
def head(self): | |||
if 'head' != self._head_prefix: | |||
return getattr(self, self._head_prefix) | |||
return super().__getattr__('head') | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
"""The default forward method runs only the backbone.""" | |||
if if_func_recieve_dict_inputs(self.backbone.forward, input): | |||
outputs = self.backbone.forward(input) | |||
else: | |||
outputs = self.backbone.forward(**input) | |||
return outputs | |||
class EncoderDecoderTaskModelBase(BaseTaskModel): | |||
""" | |||
This is the base class for encoder-decoder NLP task models. | |||
""" | |||
# The encoder backbone prefix, default to "encoder" | |||
_encoder_prefix = 'encoder' | |||
# The decoder backbone prefix, default to "decoder" | |||
_decoder_prefix = 'decoder' | |||
# The key in cfg specifying the encoder type | |||
_encoder_key_in_cfg = 'encoder_type' | |||
# The key in cfg specifying the decoder type | |||
_decoder_key_in_cfg = 'decoder_type' | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
super().__init__(model_dir, *args, **kwargs) | |||
def build_encoder(self): | |||
encoder = build_backbone( | |||
self.cfg, | |||
type_name=self._encoder_key_in_cfg, | |||
task_name=Tasks.backbone) | |||
setattr(self, self._encoder_prefix, encoder) | |||
return encoder | |||
def build_decoder(self): | |||
decoder = build_backbone( | |||
self.cfg, | |||
type_name=self._decoder_key_in_cfg, | |||
task_name=Tasks.backbone) | |||
setattr(self, self._decoder_prefix, decoder) | |||
return decoder | |||
@property | |||
def encoder_(self): | |||
return getattr(self, self._encoder_prefix) | |||
@property | |||
def decoder_(self): | |||
return getattr(self, self._decoder_prefix) | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
if if_func_recieve_dict_inputs(self.encoder_.forward, input): | |||
encoder_outputs = self.encoder_.forward(input) | |||
else: | |||
encoder_outputs = self.encoder_.forward(**input) | |||
decoder_inputs = self.project_decoder_inputs_and_mediate( | |||
input, encoder_outputs) | |||
if if_func_recieve_dict_inputs(self.decoder_.forward, input): | |||
outputs = self.decoder_.forward(decoder_inputs) | |||
else: | |||
outputs = self.decoder_.forward(**decoder_inputs) | |||
return outputs | |||
def project_decoder_inputs_and_mediate(self, input, encoder_outputs): | |||
return {**input, **encoder_outputs} |
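SequenceClassificationModel earlier in this patch is the first concrete user of the single-backbone base class; a stripped-down sketch of the same pattern is shown below. The config keys and the pooled-output handling are illustrative assumptions, not something the base class prescribes.

from typing import Any, Dict


class MyClassificationTaskModel(SingleBackboneTaskModelBase):

    def __init__(self, model_dir: str, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        # self.cfg is built from the kwargs (see BaseTaskModel); the backbone/head
        # sections are expected to mirror the configuration.json layout.
        self.build_backbone(self.cfg.backbone)
        self.build_head(self.cfg.head)

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        backbone_outputs = super().forward(input)  # backbone-only forward
        # How the pooled representation is obtained is backbone specific; this
        # pass-through is a placeholder. SequenceClassificationModel delegates
        # to the backbone's extract_pooled_outputs() helper instead.
        pooled_output = backbone_outputs
        return self.head.forward(pooled_output)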
@@ -4,6 +4,7 @@ from modelscope.utils.constant import Tasks | |||
class OutputKeys(object): | |||
LOSS = 'loss' | |||
LOGITS = 'logits' | |||
SCORES = 'scores' | |||
LABEL = 'label' | |||
@@ -22,6 +23,8 @@ class OutputKeys(object): | |||
TRANSLATION = 'translation' | |||
RESPONSE = 'response' | |||
PREDICTION = 'prediction' | |||
PREDICTIONS = 'predictions' | |||
PROBABILITIES = 'probabilities' | |||
DIALOG_STATES = 'dialog_states' | |||
VIDEO_EMBEDDING = 'video_embedding' | |||
@@ -30,7 +30,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
Tasks.nli: (Pipelines.nli, 'damo/nlp_structbert_nli_chinese-base'), | |||
Tasks.sentiment_classification: | |||
(Pipelines.sentiment_classification, | |||
'damo/nlp_structbert_sentiment-classification_chinese-base'), | |||
'damo/nlp_structbert_sentiment-classification_chinese-base' | |||
), # TODO: revise back after passing the pr | |||
Tasks.image_matting: (Pipelines.image_matting, | |||
'damo/cv_unet_image-matting'), | |||
Tasks.text_classification: (Pipelines.sentiment_analysis, | |||
@@ -2,10 +2,10 @@ | |||
from typing import Any, Dict, Union | |||
from modelscope.outputs import OutputKeys | |||
from ...metainfo import Pipelines | |||
from ...models import Model | |||
from ...models.nlp import SpaceForDialogIntent | |||
from ...outputs import OutputKeys | |||
from ...preprocessors import DialogIntentPredictionPreprocessor | |||
from ...utils.constant import Tasks | |||
from ..base import Pipeline | |||
@@ -2,10 +2,10 @@ | |||
from typing import Dict, Union | |||
from modelscope.outputs import OutputKeys | |||
from ...metainfo import Pipelines | |||
from ...models import Model | |||
from ...models.nlp import SpaceForDialogModeling | |||
from ...outputs import OutputKeys | |||
from ...preprocessors import DialogModelingPreprocessor | |||
from ...utils.constant import Tasks | |||
from ..base import Pipeline, Tensor | |||
@@ -1,8 +1,8 @@ | |||
from typing import Any, Dict, Union | |||
from modelscope.outputs import OutputKeys | |||
from ...metainfo import Pipelines | |||
from ...models import Model, SpaceForDialogStateTracking | |||
from ...outputs import OutputKeys | |||
from ...preprocessors import DialogStateTrackingPreprocessor | |||
from ...utils.constant import Tasks | |||
from ..base import Pipeline | |||
@@ -6,7 +6,7 @@ import torch | |||
from modelscope.outputs import OutputKeys | |||
from ...metainfo import Pipelines | |||
from ...models import Model | |||
from ...models.nlp import SbertForSentimentClassification | |||
from ...models.nlp import SequenceClassificationModel | |||
from ...preprocessors import SentimentClassificationPreprocessor | |||
from ...utils.constant import Tasks | |||
from ..base import Pipeline | |||
@@ -21,7 +21,7 @@ __all__ = ['SentimentClassificationPipeline'] | |||
class SentimentClassificationPipeline(Pipeline): | |||
def __init__(self, | |||
model: Union[SbertForSentimentClassification, str], | |||
model: Union[SequenceClassificationModel, str], | |||
preprocessor: SentimentClassificationPreprocessor = None, | |||
first_sequence='first_sequence', | |||
second_sequence='second_sequence', | |||
@@ -29,14 +29,14 @@ class SentimentClassificationPipeline(Pipeline): | |||
"""use `model` and `preprocessor` to create an NLP text classification pipeline for prediction | |||
Args: | |||
model (SbertForSentimentClassification): a model instance | |||
model (SequenceClassificationModel): a model instance | |||
preprocessor (SentimentClassificationPreprocessor): a preprocessor instance | |||
""" | |||
assert isinstance(model, str) or isinstance(model, SbertForSentimentClassification), \ | |||
'model must be a single str or SbertForSentimentClassification' | |||
assert isinstance(model, str) or isinstance(model, SequenceClassificationModel), \ | |||
'model must be a single str or SequenceClassificationModel' | |||
model = model if isinstance( | |||
model, | |||
SbertForSentimentClassification) else Model.from_pretrained(model) | |||
SequenceClassificationModel) else Model.from_pretrained(model) | |||
if preprocessor is None: | |||
preprocessor = SentimentClassificationPreprocessor( | |||
model.model_dir, | |||
@@ -3,12 +3,23 @@ | |||
from abc import ABC, abstractmethod | |||
from typing import Any, Dict | |||
from modelscope.utils.constant import ModeKeys | |||
class Preprocessor(ABC): | |||
def __init__(self, *args, **kwargs): | |||
self._mode = ModeKeys.INFERENCE | |||
pass | |||
@abstractmethod | |||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
pass | |||
@property | |||
def mode(self): | |||
return self._mode | |||
@mode.setter | |||
def mode(self, value): | |||
self._mode = value |
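The new `mode` property lets the trainer flip a preprocessor into training behaviour (the trainer changes further down set it to ModeKeys.TRAIN after the preprocessor is built). A hypothetical preprocessor using the flag might look like this sketch:

from typing import Any, Dict

from modelscope.preprocessors.base import Preprocessor
from modelscope.utils.constant import ModeKeys


class ToyPreprocessor(Preprocessor):

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        out = {'text': data['text']}
        if self.mode == ModeKeys.TRAIN:
            out['label'] = data['label']   # only attach labels while training
        return out


p = ToyPreprocessor()          # defaults to ModeKeys.INFERENCE
p.mode = ModeKeys.TRAIN        # the trainer does this automatically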
@@ -7,7 +7,7 @@ from transformers import AutoTokenizer | |||
from ..metainfo import Preprocessors | |||
from ..models import Model | |||
from ..utils.constant import Fields, InputFields | |||
from ..utils.constant import Fields, InputFields, ModeKeys | |||
from ..utils.hub import parse_label_mapping | |||
from ..utils.type_assert import type_assert | |||
from .base import Preprocessor | |||
@@ -52,6 +52,7 @@ class NLPPreprocessorBase(Preprocessor): | |||
self.second_sequence = kwargs.pop('second_sequence', 'second_sequence') | |||
self.tokenize_kwargs = kwargs | |||
self.tokenizer = self.build_tokenizer(model_dir) | |||
self.label2id = parse_label_mapping(self.model_dir) | |||
def build_tokenizer(self, model_dir): | |||
from sofa import SbertTokenizer | |||
@@ -83,7 +84,12 @@ class NLPPreprocessorBase(Preprocessor): | |||
text_a = data.get(self.first_sequence) | |||
text_b = data.get(self.second_sequence, None) | |||
return self.tokenizer(text_a, text_b, **self.tokenize_kwargs) | |||
rst = self.tokenizer(text_a, text_b, **self.tokenize_kwargs) | |||
if self._mode == ModeKeys.TRAIN: | |||
rst = {k: v.squeeze() for k, v in rst.items()} | |||
if self.label2id is not None and 'label' in data: | |||
rst['label'] = self.label2id[str(data['label'])] | |||
return rst | |||
@PREPROCESSORS.register_module( | |||
@@ -200,16 +206,6 @@ class SentenceSimilarityFinetunePreprocessor(SentenceSimilarityPreprocessor): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
kwargs['padding'] = 'max_length' | |||
super().__init__(model_dir, *args, **kwargs) | |||
self.label2id = parse_label_mapping(self.model_dir) | |||
@type_assert(object, (str, tuple, Dict)) | |||
def __call__(self, data: Union[str, tuple, Dict]) -> Dict[str, Any]: | |||
rst = super().__call__(data) | |||
rst = {k: v.squeeze() for k, v in rst.items()} | |||
if self.label2id is not None and 'label' in data: | |||
rst['labels'] = [] | |||
rst['labels'].append(self.label2id[str(data['label'])]) | |||
return rst | |||
@PREPROCESSORS.register_module( | |||
@@ -2,6 +2,7 @@ | |||
import os.path | |||
import random | |||
import time | |||
from collections.abc import Mapping | |||
from distutils.version import LooseVersion | |||
from functools import partial | |||
from typing import Callable, List, Optional, Tuple, Union | |||
@@ -16,8 +17,7 @@ from torch.utils.data.distributed import DistributedSampler | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.metrics import build_metric, task_default_metrics | |||
from modelscope.models.base import Model | |||
from modelscope.models.base_torch import TorchModel | |||
from modelscope.models.base import Model, TorchModel | |||
from modelscope.msdatasets.ms_dataset import MsDataset | |||
from modelscope.preprocessors import build_preprocessor | |||
from modelscope.preprocessors.base import Preprocessor | |||
@@ -26,12 +26,13 @@ from modelscope.trainers.hooks.priority import Priority, get_priority | |||
from modelscope.trainers.lrscheduler.builder import build_lr_scheduler | |||
from modelscope.trainers.optimizer.builder import build_optimizer | |||
from modelscope.utils.config import Config, ConfigDict | |||
from modelscope.utils.constant import (Hubs, ModeKeys, ModelFile, Tasks, | |||
TrainerStages) | |||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Hubs, ModeKeys, | |||
ModelFile, Tasks, TrainerStages) | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.registry import build_from_cfg | |||
from modelscope.utils.tensor_utils import torch_default_data_collator | |||
from modelscope.utils.torch_utils import get_dist_info | |||
from modelscope.utils.utils import if_func_recieve_dict_inputs | |||
from .base import BaseTrainer | |||
from .builder import TRAINERS | |||
from .default_config import DEFAULT_CONFIG | |||
@@ -79,13 +80,15 @@ class EpochBasedTrainer(BaseTrainer): | |||
optimizers: Tuple[torch.optim.Optimizer, | |||
torch.optim.lr_scheduler._LRScheduler] = (None, | |||
None), | |||
model_revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
**kwargs): | |||
if isinstance(model, str): | |||
if os.path.exists(model): | |||
self.model_dir = model if os.path.isdir( | |||
model) else os.path.dirname(model) | |||
else: | |||
self.model_dir = snapshot_download(model) | |||
self.model_dir = snapshot_download( | |||
model, revision=model_revision) | |||
cfg_file = os.path.join(self.model_dir, ModelFile.CONFIGURATION) | |||
self.model = self.build_model() | |||
else: | |||
@@ -112,6 +115,8 @@ class EpochBasedTrainer(BaseTrainer): | |||
self.preprocessor = preprocessor | |||
elif hasattr(self.cfg, 'preprocessor'): | |||
self.preprocessor = self.build_preprocessor() | |||
if self.preprocessor is not None: | |||
self.preprocessor.mode = ModeKeys.TRAIN | |||
# TODO @wenmeng.zwm add data collator option | |||
# TODO how to fill device option? | |||
self.device = int( | |||
@@ -264,7 +269,8 @@ class EpochBasedTrainer(BaseTrainer): | |||
model = Model.from_pretrained(self.model_dir) | |||
if not isinstance(model, nn.Module) and hasattr(model, 'model'): | |||
return model.model | |||
return model | |||
elif isinstance(model, nn.Module): | |||
return model | |||
def collate_fn(self, data): | |||
"""Prepare the input just before the forward function. | |||
@@ -307,7 +313,8 @@ class EpochBasedTrainer(BaseTrainer): | |||
model.train() | |||
self._mode = ModeKeys.TRAIN | |||
inputs = self.collate_fn(inputs) | |||
if not isinstance(model, Model) and isinstance(inputs, dict): | |||
if isinstance(inputs, Mapping) and not if_func_recieve_dict_inputs( | |||
model.forward, inputs): | |||
train_outputs = model.forward(**inputs) | |||
else: | |||
train_outputs = model.forward(inputs) | |||
@@ -5,6 +5,7 @@ import pickle | |||
import shutil | |||
import tempfile | |||
import time | |||
from collections.abc import Mapping | |||
import torch | |||
from torch import distributed as dist | |||
@@ -12,6 +13,7 @@ from tqdm import tqdm | |||
from modelscope.models.base import Model | |||
from modelscope.utils.torch_utils import get_dist_info | |||
from modelscope.utils.utils import if_func_recieve_dict_inputs | |||
def single_gpu_test(model, | |||
@@ -36,7 +38,10 @@ def single_gpu_test(model, | |||
if data_collate_fn is not None: | |||
data = data_collate_fn(data) | |||
with torch.no_grad(): | |||
if not isinstance(model, Model): | |||
if isinstance(data, | |||
Mapping) and not if_func_recieve_dict_inputs( | |||
model.forward, data): | |||
result = model(**data) | |||
else: | |||
result = model(data) | |||
@@ -87,7 +92,9 @@ def multi_gpu_test(model, | |||
if data_collate_fn is not None: | |||
data = data_collate_fn(data) | |||
with torch.no_grad(): | |||
if not isinstance(model, Model): | |||
if isinstance(data, | |||
Mapping) and not if_func_recieve_dict_inputs( | |||
model.forward, data): | |||
result = model(**data) | |||
else: | |||
result = model(data) | |||
@@ -57,6 +57,7 @@ class NLPTasks(object): | |||
summarization = 'summarization' | |||
question_answering = 'question-answering' | |||
zero_shot_classification = 'zero-shot-classification' | |||
backbone = 'backbone' | |||
class AudioTasks(object): | |||
@@ -173,6 +174,7 @@ DEFAULT_DATASET_REVISION = 'master' | |||
class ModeKeys: | |||
TRAIN = 'train' | |||
EVAL = 'eval' | |||
INFERENCE = 'inference' | |||
class LogKeys: | |||
@@ -6,6 +6,7 @@ from typing import List, Tuple, Union | |||
from modelscope.utils.import_utils import requires | |||
from modelscope.utils.logger import get_logger | |||
TYPE_NAME = 'type' | |||
default_group = 'default' | |||
logger = get_logger() | |||
@@ -159,15 +160,16 @@ def build_from_cfg(cfg, | |||
group_key (str, optional): The name of registry group from which | |||
module should be searched. | |||
default_args (dict, optional): Default initialization arguments. | |||
type_name (str, optional): The name of the type in the config. | |||
Returns: | |||
object: The constructed object. | |||
""" | |||
if not isinstance(cfg, dict): | |||
raise TypeError(f'cfg must be a dict, but got {type(cfg)}') | |||
if 'type' not in cfg: | |||
if default_args is None or 'type' not in default_args: | |||
if TYPE_NAME not in cfg: | |||
if default_args is None or TYPE_NAME not in default_args: | |||
raise KeyError( | |||
'`cfg` or `default_args` must contain the key "type", ' | |||
f'`cfg` or `default_args` must contain the key "{TYPE_NAME}", ' | |||
f'but got {cfg}\n{default_args}') | |||
if not isinstance(registry, Registry): | |||
raise TypeError('registry must be an modelscope.Registry object, ' | |||
@@ -184,7 +186,7 @@ def build_from_cfg(cfg, | |||
if group_key is None: | |||
group_key = default_group | |||
obj_type = args.pop('type') | |||
obj_type = args.pop(TYPE_NAME) | |||
if isinstance(obj_type, str): | |||
obj_cls = registry.get(obj_type, group_key=group_key) | |||
if obj_cls is None: | |||
@@ -196,7 +198,10 @@ def build_from_cfg(cfg, | |||
raise TypeError( | |||
f'type must be a str or valid type, but got {type(obj_type)}') | |||
try: | |||
return obj_cls(**args) | |||
if hasattr(obj_cls, '_instantiate'): | |||
return obj_cls._instantiate(**args) | |||
else: | |||
return obj_cls(**args) | |||
except Exception as e: | |||
# Normal TypeError does not print class name. | |||
raise type(e)(f'{obj_cls.__name__}: {e}') |
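The `_instantiate` hook added above is what lets task models (see BaseTaskModel._instantiate earlier in this patch) construct themselves and then load their checkpoint when built from a config. A toy illustration of the dispatch, independent of the real Registry machinery:

class Plain:
    def __init__(self, value):
        self.value = value


class WithHook:
    def __init__(self, value):
        self.value = value

    @classmethod
    def _instantiate(cls, **kwargs):
        obj = cls(**kwargs)
        obj.loaded = True   # e.g. BaseTaskModel loads its checkpoint at this point
        return obj


def build(obj_cls, **args):
    # mirrors the branch added to build_from_cfg above
    if hasattr(obj_cls, '_instantiate'):
        return obj_cls._instantiate(**args)
    return obj_cls(**args)


print(type(build(Plain, value=1)).__name__)   # Plain, built via __init__
print(build(WithHook, value=1).loaded)        # True, built via _instantiate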
@@ -2,6 +2,8 @@ | |||
# Part of the implementation is borrowed from huggingface/transformers. | |||
from collections.abc import Mapping | |||
import numpy as np | |||
def torch_nested_numpify(tensors): | |||
import torch | |||
@@ -27,9 +29,6 @@ def torch_nested_detach(tensors): | |||
def torch_default_data_collator(features): | |||
# TODO @jiangnana.jnn refine this default data collator | |||
import torch | |||
# if not isinstance(features[0], (dict, BatchEncoding)): | |||
# features = [vars(f) for f in features] | |||
first = features[0] | |||
if isinstance(first, Mapping): | |||
@@ -40,9 +39,14 @@ def torch_default_data_collator(features): | |||
if 'label' in first and first['label'] is not None: | |||
label = first['label'].item() if isinstance( | |||
first['label'], torch.Tensor) else first['label'] | |||
dtype = torch.long if isinstance(label, int) else torch.float | |||
batch['labels'] = torch.tensor([f['label'] for f in features], | |||
dtype=dtype) | |||
# MsDataset may return a 0-dimensional np.ndarray holding a single value; the following handles that case. | |||
if isinstance(label, np.ndarray): | |||
dtype = torch.long if label[( | |||
)].dtype == np.int64 else torch.float | |||
else: | |||
dtype = torch.long if isinstance(label, int) else torch.float | |||
batch['labels'] = torch.tensor( | |||
np.array([f['label'] for f in features]), dtype=dtype) | |||
elif 'label_ids' in first and first['label_ids'] is not None: | |||
if isinstance(first['label_ids'], torch.Tensor): | |||
batch['labels'] = torch.stack( | |||
@@ -0,0 +1,28 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import inspect | |||
def if_func_recieve_dict_inputs(func, inputs): | |||
"""Decide whether ``func`` should receive the whole ``inputs`` dict as a single argument, | |||
rather than having it unpacked into keyword arguments. | |||
Args: | |||
func (Callable): the target function to be inspected. | |||
inputs (dict): the inputs that will be passed to the function. | |||
Returns: | |||
bool: True if the function should receive the dict as a single argument, | |||
False if the dict should be unpacked into keyword arguments. | |||
Examples: | |||
inputs = {"input_ids": xxx, "attention_mask": xxx} | |||
function(self, inputs) then return True | |||
function(inputs) then return True | |||
function(self, input_ids, attention_mask) then return False | |||
""" | |||
signature = inspect.signature(func) | |||
func_inputs = list(signature.parameters.keys() - set(['self'])) | |||
mismatched_inputs = list(set(func_inputs) - set(inputs)) | |||
if len(func_inputs) == len(mismatched_inputs): | |||
return True | |||
else: | |||
return False |
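A quick illustration of the helper above, matching its docstring: a forward that declares a single generic parameter is handed the whole dict, while one whose parameters mirror the dict keys gets them unpacked.

from modelscope.utils.utils import if_func_recieve_dict_inputs


def takes_dict(inputs):
    return inputs


def takes_kwargs(input_ids, attention_mask):
    return input_ids, attention_mask


data = {'input_ids': [101, 102], 'attention_mask': [1, 1]}
print(if_func_recieve_dict_inputs(takes_dict, data))    # True
print(if_func_recieve_dict_inputs(takes_kwargs, data))  # False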
@@ -7,7 +7,7 @@ import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from modelscope.models.base_torch import TorchModel | |||
from modelscope.models.base import TorchModel | |||
class TorchBaseTest(unittest.TestCase): | |||
@@ -3,7 +3,8 @@ import unittest | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.models import Model | |||
from modelscope.models.nlp import SbertForSentimentClassification | |||
from modelscope.models.nlp import (SbertForSentimentClassification, | |||
SequenceClassificationModel) | |||
from modelscope.pipelines import SentimentClassificationPipeline, pipeline | |||
from modelscope.preprocessors import SentimentClassificationPreprocessor | |||
from modelscope.utils.constant import Tasks | |||
@@ -18,39 +19,44 @@ class SentimentClassificationTest(unittest.TestCase): | |||
def test_run_with_direct_file_download(self): | |||
cache_path = snapshot_download(self.model_id) | |||
tokenizer = SentimentClassificationPreprocessor(cache_path) | |||
model = SbertForSentimentClassification( | |||
cache_path, tokenizer=tokenizer) | |||
model = SequenceClassificationModel.from_pretrained( | |||
self.model_id, num_labels=2) | |||
pipeline1 = SentimentClassificationPipeline( | |||
model, preprocessor=tokenizer) | |||
pipeline2 = pipeline( | |||
Tasks.sentiment_classification, | |||
model=model, | |||
preprocessor=tokenizer) | |||
preprocessor=tokenizer, | |||
model_revision='beta') | |||
print(f'sentence1: {self.sentence1}\n' | |||
f'pipeline1:{pipeline1(input=self.sentence1)}') | |||
print() | |||
print(f'sentence1: {self.sentence1}\n' | |||
f'pipeline2: {pipeline2(input=self.sentence1)}') | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_with_model_from_modelhub(self): | |||
model = Model.from_pretrained(self.model_id) | |||
tokenizer = SentimentClassificationPreprocessor(model.model_dir) | |||
pipeline_ins = pipeline( | |||
task=Tasks.sentiment_classification, | |||
model=model, | |||
preprocessor=tokenizer) | |||
preprocessor=tokenizer, | |||
model_revision='beta') | |||
print(pipeline_ins(input=self.sentence1)) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_model_name(self): | |||
pipeline_ins = pipeline( | |||
task=Tasks.sentiment_classification, model=self.model_id) | |||
task=Tasks.sentiment_classification, | |||
model=self.model_id, | |||
model_revision='beta') | |||
print(pipeline_ins(input=self.sentence1)) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_with_default_model(self): | |||
pipeline_ins = pipeline(task=Tasks.sentiment_classification) | |||
pipeline_ins = pipeline( | |||
task=Tasks.sentiment_classification, model_revision='beta') | |||
print(pipeline_ins(input=self.sentence1)) | |||
@@ -56,6 +56,23 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
for i in range(10): | |||
self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_trainer_with_backbone_head(self): | |||
model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | |||
kwargs = dict( | |||
model=model_id, | |||
train_dataset=self.dataset, | |||
eval_dataset=self.dataset, | |||
work_dir=self.tmp_dir, | |||
model_revision='beta') | |||
trainer = build_trainer(default_args=kwargs) | |||
trainer.train() | |||
results_files = os.listdir(self.tmp_dir) | |||
self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
for i in range(10): | |||
self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_trainer_with_model_and_args(self): | |||
tmp_dir = tempfile.TemporaryDirectory().name | |||