@@ -6,25 +6,34 @@
     "second_sequence": "sentence2"
   },
   "model": {
-    "type": "structbert",
-    "attention_probs_dropout_prob": 0.1,
-    "easynlp_version": "0.0.3",
-    "gradient_checkpointing": false,
-    "hidden_act": "gelu",
-    "hidden_dropout_prob": 0.1,
-    "hidden_size": 768,
-    "initializer_range": 0.02,
-    "intermediate_size": 3072,
-    "layer_norm_eps": 1e-12,
-    "max_position_embeddings": 512,
-    "num_attention_heads": 12,
-    "num_hidden_layers": 12,
-    "pad_token_id": 0,
-    "position_embedding_type": "absolute",
-    "transformers_version": "4.6.0.dev0",
-    "type_vocab_size": 2,
-    "use_cache": true,
-    "vocab_size": 30522
+    "type": "text-classification",
+    "backbone": {
+      "type": "structbert",
+      "prefix": "encoder",
+      "attention_probs_dropout_prob": 0.1,
+      "easynlp_version": "0.0.3",
+      "gradient_checkpointing": false,
+      "hidden_act": "gelu",
+      "hidden_dropout_prob": 0.1,
+      "hidden_size": 768,
+      "initializer_range": 0.02,
+      "intermediate_size": 3072,
+      "layer_norm_eps": 1e-12,
+      "max_position_embeddings": 512,
+      "num_attention_heads": 12,
+      "num_hidden_layers": 12,
+      "pad_token_id": 0,
+      "position_embedding_type": "absolute",
+      "transformers_version": "4.6.0.dev0",
+      "type_vocab_size": 2,
+      "use_cache": true,
+      "vocab_size": 21128
+    },
+    "head": {
+      "type": "text-classification",
+      "hidden_dropout_prob": 0.1,
+      "hidden_size": 768
+    }
   },
   "pipeline": {
     "type": "sentence-similarity"
@@ -6,6 +6,9 @@ task: text-classification
 model:
   path: bert-base-sst2
+  backbone:
+    type: bert
+    prefix: bert
   attention_probs_dropout_prob: 0.1
   bos_token_id: 0
   eos_token_id: 2
@@ -33,6 +33,16 @@ class Models(object):
     imagen = 'imagen-text-to-image-synthesis'
+
+
+class TaskModels(object):
+    # nlp task
+    text_classification = 'text-classification'
+
+
+class Heads(object):
+    # nlp heads
+    text_classification = 'text-classification'
 
 
 class Pipelines(object):
     """ Names for different pipelines.

@@ -17,6 +17,7 @@ class MetricKeys(object):
 task_default_metrics = {
     Tasks.sentence_similarity: [Metrics.seq_cls_metric],
+    Tasks.sentiment_classification: [Metrics.seq_cls_metric],
     Tasks.text_generation: [Metrics.text_gen_metric],
 }
@@ -29,6 +29,8 @@ try:
         SbertForZeroShotClassification, SpaceForDialogIntent,
         SpaceForDialogModeling, SpaceForDialogStateTracking,
         StructBertForMaskedLM, VecoForMaskedLM)
+    from .nlp.heads import (SequenceClassificationHead)
+    from .nlp.backbones import (SbertModel)
 except ModuleNotFoundError as e:
     if str(e) == "No module named 'pytorch'":
         pass

@@ -0,0 +1,4 @@
from .base_head import * # noqa F403
from .base_model import * # noqa F403
from .base_torch_head import * # noqa F403
from .base_torch_model import * # noqa F403
@@ -0,0 +1,48 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from abc import ABC, abstractmethod
from typing import Dict, List, Union

import numpy as np

from ...utils.config import ConfigDict
from ...utils.logger import get_logger
from .base_model import Model

logger = get_logger()

Tensor = Union['torch.Tensor', 'tf.Tensor']
Input = Union[Dict[str, Tensor], Model]


class Head(ABC):
    """
    The base class for task heads. It defines the interface that a
    task-specific head must implement.
    """

    def __init__(self, **kwargs):
        self.config = ConfigDict(kwargs)

    @abstractmethod
    def forward(self, input: Input) -> Dict[str, Tensor]:
        """
        Run the downstream task on the backbone model's output.

        Args:
            input: The tensor output of the backbone model, or the backbone
                model itself (text generation needs the model as input).

        Returns: The output of the downstream task.
        """
        pass

    @abstractmethod
    def compute_loss(self, outputs: Dict[str, Tensor],
                     labels) -> Dict[str, Tensor]:
        """
        Compute the head's loss during fine-tuning.

        Args:
            outputs (Dict[str, Tensor]): the output of the model forward pass.

        Returns: the loss (Dict[str, Tensor]).
        """
        pass
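For orientation, a minimal sketch of what a concrete head built on this interface could look like. Illustrative only: the PR's real `SequenceClassificationHead` (imported in `models/__init__.py` above) is not part of this diff, so the class name and layer layout below are guesses; it assumes `TorchHead` is re-exported from `modelscope.models.base` as the new package `__init__` suggests.

```python
# Hypothetical head: not the PR's SequenceClassificationHead.
from typing import Dict

import torch
import torch.nn.functional as F
from torch import nn

from modelscope.models.base import TorchHead


class ToySequenceClassificationHead(TorchHead):

    def __init__(self, hidden_size=768, hidden_dropout_prob=0.1,
                 num_labels=2, **kwargs):
        # Head.__init__ stores all kwargs in self.config (a ConfigDict).
        super().__init__(
            hidden_size=hidden_size,
            hidden_dropout_prob=hidden_dropout_prob,
            num_labels=num_labels,
            **kwargs)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input: torch.Tensor) -> Dict[str, torch.Tensor]:
        # `input` is the backbone's pooled output, shape [batch, hidden_size].
        return {'logits': self.classifier(self.dropout(input))}

    def compute_loss(self, outputs: Dict[str, torch.Tensor],
                     labels) -> Dict[str, torch.Tensor]:
        return {'loss': F.cross_entropy(outputs['logits'], labels)}


head = ToySequenceClassificationHead(hidden_size=8, num_labels=3)
out = head.forward(torch.randn(4, 8))
loss = head.compute_loss(out, torch.tensor([0, 1, 2, 1]))
```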
@@ -4,6 +4,8 @@ import os.path as osp
 from abc import ABC, abstractmethod
 from typing import Dict, Optional, Union
 
+import numpy as np
+
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models.builder import build_model
 from modelscope.utils.config import Config

@@ -25,6 +27,15 @@ class Model(ABC):
     @abstractmethod
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        """
+        Run the forward pass for a model.
+
+        Args:
+            input (Dict[str, Tensor]): the dict of the model inputs for the forward method.
+
+        Returns:
+            Dict[str, Tensor]: output from the model forward pass.
+        """
         pass
 
     def postprocess(self, input: Dict[str, Tensor],

@@ -41,6 +52,15 @@ class Model(ABC):
         """
         return input
 
+    @classmethod
+    def _instantiate(cls, **kwargs):
+        """ Define how a model is instantiated. The default is to call the
+        constructor. If a task model does not load its weights in the
+        constructor but in a separate load_model method instead, it should
+        override this method.
+        """
+        return cls(**kwargs)
+
     @classmethod
     def from_pretrained(cls,
                         model_name_or_path: str,

@@ -71,6 +91,7 @@ class Model(ABC):
             cfg, 'pipeline'), 'pipeline config is missing from config file.'
         pipeline_cfg = cfg.pipeline
         # TODO @wenmeng.zwm may should manually initialize model after model building
+
         if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
             model_cfg.type = model_cfg.model_type

@@ -78,7 +99,8 @@ class Model(ABC):
         for k, v in kwargs.items():
             model_cfg[k] = v
 
-        model = build_model(model_cfg, task_name)
+        model = build_model(
+            model_cfg, task_name=task_name, default_args=kwargs)
 
         # dynamically add pipeline info to model for pipeline inference
         model.pipeline = pipeline_cfg
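A hedged sketch of when `_instantiate` would be overridden: a task model whose constructor only builds modules and defers weight loading to a separate step. `LazyLoadedTaskModel` and `load_model` are hypothetical names, and this assumes `Model` is re-exported from `modelscope.models.base` by the new package `__init__`.

```python
# Hypothetical subclass: the constructor builds the network, the checkpoint
# is read in a separate load_model step, so _instantiate is overridden.
from modelscope.models.base import Model


class LazyLoadedTaskModel(Model):

    def __init__(self, model_dir=None, **kwargs):
        super().__init__(model_dir)
        # ... build submodules here, but do not read any checkpoint ...

    def load_model(self):
        # ... read weights from self.model_dir ...
        return self

    def forward(self, input):
        raise NotImplementedError

    @classmethod
    def _instantiate(cls, **kwargs):
        # Construct first, then explicitly load the weights.
        return cls(**kwargs).load_model()
```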
@@ -0,0 +1,30 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path
import re
from typing import Dict, Optional, Union

import torch
from torch import nn

from ...utils.logger import get_logger
from .base_head import Head

logger = get_logger(__name__)


class TorchHead(Head, torch.nn.Module):
    """ Base head interface for pytorch
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        torch.nn.Module.__init__(self)

    def forward(self, inputs: Dict[str,
                                   torch.Tensor]) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def compute_loss(self, outputs: Dict[str, torch.Tensor],
                     labels) -> Dict[str, torch.Tensor]:
        raise NotImplementedError
@@ -0,0 +1,55 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Optional, Union

import torch
from torch import nn

from ...utils.logger import get_logger
from .base_model import Model

logger = get_logger(__name__)


class TorchModel(Model, torch.nn.Module):
    """ Base model interface for pytorch
    """

    def __init__(self, model_dir=None, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        torch.nn.Module.__init__(self)

    def forward(self, inputs: Dict[str,
                                   torch.Tensor]) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def post_init(self):
        """
        A method executed at the end of each model initialization, to execute code that needs the model's
        modules properly initialized (such as weight initialization).
        """
        self.init_weights()

    def init_weights(self):
        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def compute_loss(self, outputs: Dict[str, Any], labels):
        raise NotImplementedError()
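A minimal, hypothetical subclass illustrating the intended flow: build the layers in `__init__`, then call `post_init()` so `_init_weights` is applied to every submodule. `TinyRegressor` is not part of the PR.

```python
# Hypothetical example, assuming TorchModel is exported from modelscope.models.base.
import torch
from torch import nn

from modelscope.models.base import TorchModel


class TinyRegressor(TorchModel):

    def __init__(self, model_dir=None, hidden=16):
        super().__init__(model_dir)
        self.proj = nn.Linear(hidden, 1)
        # Runs _init_weights (normal_ init for Linear, etc.) over all submodules.
        self.post_init()

    def forward(self, inputs):
        return {'prediction': self.proj(inputs['features'])}


model = TinyRegressor(hidden=8)
out = model.forward({'features': torch.randn(2, 8)})
```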
@@ -1,23 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict

import torch

from .base import Model


class TorchModel(Model, torch.nn.Module):
    """ Base model interface for pytorch
    """

    def __init__(self, model_dir=None, *args, **kwargs):
        # init reference: https://stackoverflow.com/questions\
        # /9575409/calling-parent-class-init-with-multiple-inheritance-whats-the-right-way
        super().__init__(model_dir)
        super(Model, self).__init__()

    def forward(self, inputs: Dict[str,
                                   torch.Tensor]) -> Dict[str, torch.Tensor]:
        raise NotImplementedError
@@ -1,9 +1,11 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 from modelscope.utils.config import ConfigDict
-from modelscope.utils.registry import Registry, build_from_cfg
+from modelscope.utils.registry import TYPE_NAME, Registry, build_from_cfg
 
 MODELS = Registry('models')
+BACKBONES = Registry('backbones')
+HEADS = Registry('heads')
 
 
 def build_model(cfg: ConfigDict,

@@ -19,3 +21,29 @@ def build_model(cfg: ConfigDict,
     """
     return build_from_cfg(
         cfg, MODELS, group_key=task_name, default_args=default_args)
+
+
+def build_backbone(cfg: ConfigDict,
+                   field: str = None,
+                   default_args: dict = None):
+    """ build backbone given backbone config dict
+
+    Args:
+        cfg (:obj:`ConfigDict`): config dict for backbone object.
+        field (str, optional): the field the backbone belongs to, e.g. CV or NLP.
+        default_args (dict, optional): Default initialization arguments.
+    """
+    return build_from_cfg(
+        cfg, BACKBONES, group_key=field, default_args=default_args)
+
+
+def build_head(cfg: ConfigDict, default_args: dict = None):
+    """ build head given config dict
+
+    Args:
+        cfg (:obj:`ConfigDict`): config dict for head object.
+        default_args (dict, optional): Default initialization arguments.
+    """
+    return build_from_cfg(
+        cfg, HEADS, group_key=cfg[TYPE_NAME], default_args=default_args)
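The three registries are keyed differently: MODELS is grouped by task name, BACKBONES by field (e.g. `Fields.nlp`), and HEADS by the head's own `type` (`build_head` passes `cfg[TYPE_NAME]` as the group key, and `build_from_cfg` resolves the module name from the same field, so in practice a head is registered under its type string for both). A hedged sketch of the corresponding decorators; the classes and `'my-*'` names are placeholders invented to avoid clashing with real registrations.

```python
from modelscope.models.builder import BACKBONES, HEADS, MODELS
from modelscope.utils.constant import Fields, Tasks


@MODELS.register_module(Tasks.sentence_similarity, module_name='my-task-model')
class MyTaskModel:
    """Grouped by task; looked up when cfg['type'] == 'my-task-model'."""


@BACKBONES.register_module(Fields.nlp, module_name='my-backbone')
class MyBackbone:
    """Grouped by field, like SbertModel under (Fields.nlp, Models.structbert)."""


@HEADS.register_module('my-head', module_name='my-head')
class MyHead:
    """build_head uses cfg['type'] for both the group key and the name."""
```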
@@ -1,6 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from modelscope.utils.error import TENSORFLOW_IMPORT_WARNING
+from ...utils.error import TENSORFLOW_IMPORT_WARNING
+from .backbones import * # noqa F403
 from .bert_for_sequence_classification import * # noqa F403
+from .heads import * # noqa F403
 from .masked_language import * # noqa F403
 from .nncrf_for_named_entity_recognition import * # noqa F403
 from .palm_for_text_generation import * # noqa F403

@@ -9,9 +11,10 @@ from .sbert_for_sentence_similarity import * # noqa F403
 from .sbert_for_sentiment_classification import * # noqa F403
 from .sbert_for_token_classification import * # noqa F403
 from .sbert_for_zero_shot_classification import * # noqa F403
-from .space.dialog_intent_prediction_model import * # noqa F403
-from .space.dialog_modeling_model import * # noqa F403
-from .space.dialog_state_tracking_model import * # noqa F403
+from .sequence_classification import * # noqa F403
+from .space_for_dialog_intent_prediction import * # noqa F403
+from .space_for_dialog_modeling import * # noqa F403
+from .space_for_dialog_state_tracking import * # noqa F403
 
 try:
     from .csanmt_for_translation import CsanmtForTranslation

@@ -0,0 +1,4 @@
from .space import SpaceGenerator, SpaceModelBase
from .structbert import SbertModel

__all__ = ['SbertModel', 'SpaceGenerator', 'SpaceModelBase']
@@ -0,0 +1,2 @@
from .model.generator import Generator as SpaceGenerator
from .model.model_base import SpaceModelBase

@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .....utils.nlp.space.criterions import compute_kl_loss
+from ......utils.nlp.space.criterions import compute_kl_loss
 
 from .unified_transformer import UnifiedTransformer

@@ -4,7 +4,7 @@ import os
 import torch.nn as nn
 
-from .....utils.constant import ModelFile
+from ......utils.constant import ModelFile
 
 
 class SpaceModelBase(nn.Module):

@@ -0,0 +1 @@
from .modeling_sbert import SbertModel
@@ -0,0 +1,166 @@ | |||||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||||
# All rights reserved. | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
import torch | |||||
from torch import nn | |||||
from .....utils.logger import get_logger | |||||
logger = get_logger(__name__) | |||||
def _symmetric_kl_div(logits1, logits2, attention_mask=None): | |||||
""" | |||||
Calclate two logits' the KL div value symmetrically. | |||||
:param logits1: The first logit. | |||||
:param logits2: The second logit. | |||||
:param attention_mask: An optional attention_mask which is used to mask some element out. | |||||
This is usually useful in token_classification tasks. | |||||
If the shape of logits is [N1, N2, ... Nn, D], the shape of attention_mask should be [N1, N2, ... Nn] | |||||
:return: The mean loss. | |||||
""" | |||||
labels_num = logits1.shape[-1] | |||||
KLDiv = nn.KLDivLoss(reduction='none') | |||||
loss = torch.sum( | |||||
KLDiv(nn.LogSoftmax(dim=-1)(logits1), | |||||
nn.Softmax(dim=-1)(logits2)), | |||||
dim=-1) + torch.sum( | |||||
KLDiv(nn.LogSoftmax(dim=-1)(logits2), | |||||
nn.Softmax(dim=-1)(logits1)), | |||||
dim=-1) | |||||
if attention_mask is not None: | |||||
loss = torch.sum( | |||||
loss * attention_mask) / torch.sum(attention_mask) / labels_num | |||||
else: | |||||
loss = torch.mean(loss) / labels_num | |||||
return loss | |||||
def compute_adv_loss(embedding, | |||||
model, | |||||
ori_logits, | |||||
ori_loss, | |||||
adv_grad_factor, | |||||
adv_bound=None, | |||||
sigma=5e-6, | |||||
**kwargs): | |||||
""" | |||||
Calculate the adv loss of the model. | |||||
:param embedding: Original sentense embedding | |||||
:param model: The model or the forward function(including decoder/classifier), accept kwargs as input, output logits | |||||
:param ori_logits: The original logits outputed from the model function | |||||
:param ori_loss: The original loss | |||||
:param adv_grad_factor: This factor will be multipled by the KL loss grad and then the result will be added to | |||||
the original embedding. | |||||
More details please check:https://arxiv.org/abs/1908.04577 | |||||
The range of this value always be 1e-3~1e-7 | |||||
:param adv_bound: adv_bound is used to cut the top and the bottom bound of the produced embedding. | |||||
If not proveded, 2 * sigma will be used as the adv_bound factor | |||||
:param sigma: The std factor used to produce a 0 mean normal distribution. | |||||
If adv_bound not proveded, 2 * sigma will be used as the adv_bound factor | |||||
:param kwargs: the input param used in model function | |||||
:return: The original loss adds the adv loss | |||||
""" | |||||
adv_bound = adv_bound if adv_bound is not None else 2 * sigma | |||||
embedding_1 = embedding + embedding.data.new(embedding.size()).normal_( | |||||
0, sigma) # 95% in +- 1e-5 | |||||
kwargs.pop('input_ids') | |||||
if 'inputs_embeds' in kwargs: | |||||
kwargs.pop('inputs_embeds') | |||||
with_attention_mask = False if 'with_attention_mask' not in kwargs else kwargs[ | |||||
'with_attention_mask'] | |||||
attention_mask = kwargs['attention_mask'] | |||||
if not with_attention_mask: | |||||
attention_mask = None | |||||
if 'with_attention_mask' in kwargs: | |||||
kwargs.pop('with_attention_mask') | |||||
outputs = model(**kwargs, inputs_embeds=embedding_1) | |||||
v1_logits = outputs.logits | |||||
loss = _symmetric_kl_div(ori_logits, v1_logits, attention_mask) | |||||
emb_grad = torch.autograd.grad(loss, embedding_1)[0].data | |||||
emb_grad_norm = emb_grad.norm( | |||||
dim=2, keepdim=True, p=float('inf')).max( | |||||
1, keepdim=True)[0] | |||||
is_nan = torch.any(torch.isnan(emb_grad_norm)) | |||||
if is_nan: | |||||
logger.warning('NaN occurred when calculating the adv loss.')
return ori_loss | |||||
emb_grad = emb_grad / emb_grad_norm | |||||
embedding_2 = embedding_1 + adv_grad_factor * emb_grad | |||||
embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2) | |||||
embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2) | |||||
outputs = model(**kwargs, inputs_embeds=embedding_2) | |||||
adv_logits = outputs.logits | |||||
adv_loss = _symmetric_kl_div(ori_logits, adv_logits, attention_mask) | |||||
return ori_loss + adv_loss | |||||
def compute_adv_loss_pair(embedding, | |||||
model, | |||||
start_logits, | |||||
end_logits, | |||||
ori_loss, | |||||
adv_grad_factor, | |||||
adv_bound=None, | |||||
sigma=5e-6, | |||||
**kwargs): | |||||
""" | |||||
Calculate the adv loss of the model. This function is used in the pair logits scenerio. | |||||
:param embedding: Original sentense embedding | |||||
:param model: The model or the forward function(including decoder/classifier), accept kwargs as input, output logits | |||||
:param start_logits: The original start logits outputed from the model function | |||||
:param end_logits: The original end logits outputed from the model function | |||||
:param ori_loss: The original loss | |||||
:param adv_grad_factor: This factor will be multipled by the KL loss grad and then the result will be added to | |||||
the original embedding. | |||||
More details please check:https://arxiv.org/abs/1908.04577 | |||||
The range of this value always be 1e-3~1e-7 | |||||
:param adv_bound: adv_bound is used to cut the top and the bottom bound of the produced embedding. | |||||
If not proveded, 2 * sigma will be used as the adv_bound factor | |||||
:param sigma: The std factor used to produce a 0 mean normal distribution. | |||||
If adv_bound not proveded, 2 * sigma will be used as the adv_bound factor | |||||
:param kwargs: the input param used in model function | |||||
:return: The original loss adds the adv loss | |||||
""" | |||||
adv_bound = adv_bound if adv_bound is not None else 2 * sigma | |||||
embedding_1 = embedding + embedding.data.new(embedding.size()).normal_( | |||||
0, sigma) # 95% in +- 1e-5 | |||||
kwargs.pop('input_ids') | |||||
if 'inputs_embeds' in kwargs: | |||||
kwargs.pop('inputs_embeds') | |||||
outputs = model(**kwargs, inputs_embeds=embedding_1) | |||||
v1_logits_start, v1_logits_end = outputs.logits | |||||
loss = _symmetric_kl_div(start_logits, | |||||
v1_logits_start) + _symmetric_kl_div( | |||||
end_logits, v1_logits_end) | |||||
loss = loss / 2 | |||||
emb_grad = torch.autograd.grad(loss, embedding_1)[0].data | |||||
emb_grad_norm = emb_grad.norm( | |||||
dim=2, keepdim=True, p=float('inf')).max( | |||||
1, keepdim=True)[0] | |||||
is_nan = torch.any(torch.isnan(emb_grad_norm)) | |||||
if is_nan: | |||||
logger.warning('NaN occurred when calculating the pair adv loss.')
return ori_loss | |||||
emb_grad = emb_grad / emb_grad_norm | |||||
embedding_2 = embedding_1 + adv_grad_factor * emb_grad | |||||
embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2) | |||||
embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2) | |||||
outputs = model(**kwargs, inputs_embeds=embedding_2) | |||||
adv_logits_start, adv_logits_end = outputs.logits | |||||
adv_loss = _symmetric_kl_div(start_logits, | |||||
adv_logits_start) + _symmetric_kl_div( | |||||
end_logits, adv_logits_end) | |||||
return ori_loss + adv_loss |
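A hedged usage sketch (not taken from the PR) of how `compute_adv_loss` is meant to be called during training: the caller keeps the embedding tensor in the autograd graph, computes the ordinary loss once, and lets the helper run the two perturbed forward passes. The toy classifier is an invention, and the import path is inferred from the relative imports at the top of this file.

```python
# Toy model exposing the call pattern compute_adv_loss expects: a forward that
# accepts `inputs_embeds` plus the usual kwargs and returns an object with `.logits`.
from types import SimpleNamespace

import torch
from torch import nn

# Assumed module path for the new file.
from modelscope.models.nlp.backbones.structbert.adv_utils import compute_adv_loss


class ToyClassifier(nn.Module):

    def __init__(self, vocab=100, hidden=32, labels=2):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.cls = nn.Linear(hidden, labels)

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        embeds = self.embed(input_ids) if inputs_embeds is None else inputs_embeds
        return SimpleNamespace(logits=self.cls(embeds.mean(dim=1)))


model = ToyClassifier()
input_ids = torch.randint(0, 100, (4, 16))
attention_mask = torch.ones(4, 16)
labels = torch.randint(0, 2, (4, ))

embedding = model.embed(input_ids)  # [batch, seq, hidden], stays in the graph
outputs = model(inputs_embeds=embedding, attention_mask=attention_mask)
ori_loss = nn.functional.cross_entropy(outputs.logits, labels)

total_loss = compute_adv_loss(
    embedding, model, outputs.logits, ori_loss, adv_grad_factor=5e-5,
    input_ids=input_ids, attention_mask=attention_mask)
total_loss.backward()
```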
@@ -0,0 +1,131 @@ | |||||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||||
# All rights reserved. | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" SBERT model configuration, mainly copied from :class:`~transformers.BertConfig` """ | |||||
from transformers import PretrainedConfig | |||||
from .....utils import logger as logging | |||||
logger = logging.get_logger(__name__) | |||||
class SbertConfig(PretrainedConfig): | |||||
r""" | |||||
This is the configuration class to store the configuration of a :class:`~sofa.models.SbertModel`. | |||||
It is used to instantiate a SBERT model according to the specified arguments. | |||||
Configuration objects inherit from :class:`~sofa.utils.PretrainedConfig` and can be used to control the model | |||||
outputs. Read the documentation from :class:`~sofa.utils.PretrainedConfig` for more information. | |||||
Args: | |||||
vocab_size (:obj:`int`, `optional`, defaults to 30522): | |||||
Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the | |||||
:obj:`inputs_ids` passed when calling :class:`~transformers.BertModel` or | |||||
:class:`~transformers.TFBertModel`. | |||||
hidden_size (:obj:`int`, `optional`, defaults to 768): | |||||
Dimensionality of the encoder layers and the pooler layer. | |||||
num_hidden_layers (:obj:`int`, `optional`, defaults to 12): | |||||
Number of hidden layers in the Transformer encoder. | |||||
num_attention_heads (:obj:`int`, `optional`, defaults to 12): | |||||
Number of attention heads for each attention layer in the Transformer encoder. | |||||
intermediate_size (:obj:`int`, `optional`, defaults to 3072): | |||||
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. | |||||
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`): | |||||
The non-linear activation function (function or string) in the encoder and pooler. If string, | |||||
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported. | |||||
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): | |||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. | |||||
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): | |||||
The dropout ratio for the attention probabilities. | |||||
max_position_embeddings (:obj:`int`, `optional`, defaults to 512): | |||||
The maximum sequence length that this model might ever be used with. Typically set this to something large | |||||
just in case (e.g., 512 or 1024 or 2048). | |||||
type_vocab_size (:obj:`int`, `optional`, defaults to 2): | |||||
The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BertModel` or | |||||
:class:`~transformers.TFBertModel`. | |||||
initializer_range (:obj:`float`, `optional`, defaults to 0.02): | |||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices. | |||||
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): | |||||
The epsilon used by the layer normalization layers. | |||||
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): | |||||
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, | |||||
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on | |||||
:obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.) | |||||
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to | |||||
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.) | |||||
<https://arxiv.org/abs/2009.13658>`__. | |||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): | |||||
Whether or not the model should return the last key/values attentions (not used by all models). Only | |||||
relevant if ``config.is_decoder=True``. | |||||
classifier_dropout (:obj:`float`, `optional`): | |||||
The dropout ratio for the classification head. | |||||
adv_grad_factor (:obj:`float`, `optional`): This factor is multiplied by the KL loss grad and the
result is added to the original embedding.
For more details please check: https://arxiv.org/abs/1908.04577
This value typically lies in the range 1e-7 ~ 1e-3.
adv_bound (:obj:`float`, `optional`): adv_bound is used to clip the top and bottom bound of
the produced embedding.
If not provided, 2 * sigma will be used as the adv_bound factor.
sigma (:obj:`float`, `optional`): The std factor used to produce a zero-mean normal distribution.
If adv_bound is not provided, 2 * sigma will be used as the adv_bound factor.
""" | |||||
model_type = 'sbert' | |||||
def __init__(self, | |||||
vocab_size=30522, | |||||
hidden_size=768, | |||||
num_hidden_layers=12, | |||||
num_attention_heads=12, | |||||
intermediate_size=3072, | |||||
hidden_act='gelu', | |||||
hidden_dropout_prob=0.1, | |||||
attention_probs_dropout_prob=0.1, | |||||
max_position_embeddings=512, | |||||
type_vocab_size=2, | |||||
initializer_range=0.02, | |||||
layer_norm_eps=1e-12, | |||||
position_embedding_type='absolute', | |||||
use_cache=True, | |||||
classifier_dropout=None, | |||||
**kwargs): | |||||
super().__init__(**kwargs) | |||||
self.vocab_size = vocab_size | |||||
self.hidden_size = hidden_size | |||||
self.num_hidden_layers = num_hidden_layers | |||||
self.num_attention_heads = num_attention_heads | |||||
self.hidden_act = hidden_act | |||||
self.intermediate_size = intermediate_size | |||||
self.hidden_dropout_prob = hidden_dropout_prob | |||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||||
self.max_position_embeddings = max_position_embeddings | |||||
self.type_vocab_size = type_vocab_size | |||||
self.initializer_range = initializer_range | |||||
self.layer_norm_eps = layer_norm_eps | |||||
self.position_embedding_type = position_embedding_type | |||||
self.use_cache = use_cache | |||||
self.classifier_dropout = classifier_dropout | |||||
# adv_grad_factor, used in adv loss. | |||||
# Users can check adv_utils.py for details. | |||||
# If adv_grad_factor is set to None, no adv loss will be applied to the model.
self.adv_grad_factor = 5e-5 if 'adv_grad_factor' not in kwargs else kwargs[ | |||||
'adv_grad_factor'] | |||||
# sigma value, used in adv loss. | |||||
self.sigma = 5e-6 if 'sigma' not in kwargs else kwargs['sigma'] | |||||
# adv_bound value, used in adv loss. | |||||
self.adv_bound = 2 * self.sigma if 'adv_bound' not in kwargs else kwargs[ | |||||
'adv_bound'] |
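A small usage sketch of the new adversarial fields, with assumed values: `adv_grad_factor`, `sigma`, and `adv_bound` arrive via `**kwargs` and otherwise fall back to the defaults set in `__init__` above; the import path is inferred from the file layout in this PR.

```python
from modelscope.models.nlp.backbones.structbert.configuration_sbert import SbertConfig

config = SbertConfig(
    vocab_size=21128,        # matches the Chinese StructBERT backbone config above
    adv_grad_factor=1e-5,    # pass None to disable the adv loss (see comment above)
    sigma=5e-6)

# adv_bound was not given, so it defaults to 2 * sigma.
assert config.adv_bound == 2 * config.sigma
```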
@@ -0,0 +1,815 @@ | |||||
import math | |||||
from dataclasses import dataclass | |||||
from typing import Optional, Tuple, Union | |||||
import torch | |||||
import torch.utils.checkpoint | |||||
from packaging import version | |||||
from torch import nn | |||||
from transformers import PreTrainedModel | |||||
from transformers.activations import ACT2FN | |||||
from transformers.modeling_outputs import ( | |||||
BaseModelOutputWithPastAndCrossAttentions, | |||||
BaseModelOutputWithPoolingAndCrossAttentions, ModelOutput) | |||||
from transformers.modeling_utils import (apply_chunking_to_forward, | |||||
find_pruneable_heads_and_indices, | |||||
prune_linear_layer) | |||||
from .....metainfo import Models | |||||
from .....utils.constant import Fields | |||||
from .....utils.logger import get_logger | |||||
from ....base import TorchModel | |||||
from ....builder import BACKBONES | |||||
from .configuration_sbert import SbertConfig | |||||
logger = get_logger(__name__) | |||||
@BACKBONES.register_module(Fields.nlp, module_name=Models.structbert) | |||||
class SbertModel(TorchModel, PreTrainedModel): | |||||
""" | |||||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of | |||||
cross-attention is added between the self-attention layers, following the architecture described in `Attention is | |||||
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, | |||||
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. | |||||
To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
set to :obj:`True`. To be used in a Seq2Seq model, the model needs to be initialized with both the :obj:`is_decoder`
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an | |||||
input to the forward pass. | |||||
""" | |||||
def __init__(self, model_dir=None, add_pooling_layer=True, **config): | |||||
""" | |||||
Args: | |||||
model_dir (str, optional): The model checkpoint directory. Defaults to None. | |||||
add_pooling_layer (bool, optional): whether to pool the output of the last hidden layer. Defaults to True.
""" | |||||
config = SbertConfig(**config) | |||||
super().__init__(model_dir) | |||||
self.config = config | |||||
self.embeddings = SbertEmbeddings(config) | |||||
self.encoder = SbertEncoder(config) | |||||
self.pooler = SbertPooler(config) if add_pooling_layer else None | |||||
self.init_weights() | |||||
def get_input_embeddings(self): | |||||
return self.embeddings.word_embeddings | |||||
def set_input_embeddings(self, value): | |||||
self.embeddings.word_embeddings = value | |||||
def _prune_heads(self, heads_to_prune): | |||||
""" | |||||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base | |||||
class PreTrainedModel | |||||
""" | |||||
for layer, heads in heads_to_prune.items(): | |||||
self.encoder.layer[layer].attention.prune_heads(heads) | |||||
def forward(self, | |||||
input_ids=None, | |||||
attention_mask=None, | |||||
token_type_ids=None, | |||||
position_ids=None, | |||||
head_mask=None, | |||||
inputs_embeds=None, | |||||
encoder_hidden_states=None, | |||||
encoder_attention_mask=None, | |||||
past_key_values=None, | |||||
use_cache=None, | |||||
output_attentions=None, | |||||
output_hidden_states=None, | |||||
return_dict=None, | |||||
**kwargs): | |||||
r""" | |||||
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)` | |||||
, `optional`): | |||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if | |||||
the model is configured as a decoder. | |||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in | |||||
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: | |||||
- 1 for tokens that are **not masked**, | |||||
- 0 for tokens that are **masked**. | |||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` | |||||
with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, | |||||
sequence_length - 1, embed_size_per_head)`): | |||||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. | |||||
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` | |||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` | |||||
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. | |||||
use_cache (:obj:`bool`, `optional`): | |||||
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up | |||||
decoding (see :obj:`past_key_values`). | |||||
""" | |||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |||||
output_hidden_states = ( | |||||
output_hidden_states if output_hidden_states is not None else | |||||
self.config.output_hidden_states) | |||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||||
if self.config.is_decoder: | |||||
use_cache = use_cache if use_cache is not None else self.config.use_cache | |||||
else: | |||||
use_cache = False | |||||
if input_ids is not None and inputs_embeds is not None: | |||||
raise ValueError( | |||||
'You cannot specify both input_ids and inputs_embeds at the same time' | |||||
) | |||||
elif input_ids is not None: | |||||
input_shape = input_ids.size() | |||||
elif inputs_embeds is not None: | |||||
input_shape = inputs_embeds.size()[:-1] | |||||
else: | |||||
raise ValueError( | |||||
'You have to specify either input_ids or inputs_embeds') | |||||
batch_size, seq_length = input_shape | |||||
device = input_ids.device if input_ids is not None else inputs_embeds.device | |||||
# past_key_values_length | |||||
past_key_values_length = past_key_values[0][0].shape[ | |||||
2] if past_key_values is not None else 0 | |||||
if attention_mask is None: | |||||
attention_mask = torch.ones( | |||||
((batch_size, seq_length + past_key_values_length)), | |||||
device=device) | |||||
if token_type_ids is None: | |||||
if hasattr(self.embeddings, 'token_type_ids'): | |||||
buffered_token_type_ids = self.embeddings.token_type_ids[:, : | |||||
seq_length] | |||||
buffered_token_type_ids_expanded = buffered_token_type_ids.expand( | |||||
batch_size, seq_length) | |||||
token_type_ids = buffered_token_type_ids_expanded | |||||
else: | |||||
token_type_ids = torch.zeros( | |||||
input_shape, dtype=torch.long, device=device) | |||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] | |||||
# ourselves in which case we just need to make it broadcastable to all heads. | |||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( | |||||
attention_mask, input_shape, device) | |||||
# If a 2D or 3D attention mask is provided for the cross-attention | |||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] | |||||
if self.config.is_decoder and encoder_hidden_states is not None: | |||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size( | |||||
) | |||||
encoder_hidden_shape = (encoder_batch_size, | |||||
encoder_sequence_length) | |||||
if encoder_attention_mask is None: | |||||
encoder_attention_mask = torch.ones( | |||||
encoder_hidden_shape, device=device) | |||||
encoder_extended_attention_mask = self.invert_attention_mask( | |||||
encoder_attention_mask) | |||||
else: | |||||
encoder_extended_attention_mask = None | |||||
# Prepare head mask if needed | |||||
# 1.0 in head_mask indicate we keep the head | |||||
# attention_probs has shape bsz x n_heads x N x N | |||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] | |||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] | |||||
head_mask = self.get_head_mask(head_mask, | |||||
self.config.num_hidden_layers) | |||||
embedding_output, orignal_embeds = self.embeddings( | |||||
input_ids=input_ids, | |||||
position_ids=position_ids, | |||||
token_type_ids=token_type_ids, | |||||
inputs_embeds=inputs_embeds, | |||||
past_key_values_length=past_key_values_length, | |||||
return_inputs_embeds=True, | |||||
) | |||||
encoder_outputs = self.encoder( | |||||
embedding_output, | |||||
attention_mask=extended_attention_mask, | |||||
head_mask=head_mask, | |||||
encoder_hidden_states=encoder_hidden_states, | |||||
encoder_attention_mask=encoder_extended_attention_mask, | |||||
past_key_values=past_key_values, | |||||
use_cache=use_cache, | |||||
output_attentions=output_attentions, | |||||
output_hidden_states=output_hidden_states, | |||||
return_dict=return_dict, | |||||
) | |||||
sequence_output = encoder_outputs[0] | |||||
pooled_output = self.pooler( | |||||
sequence_output) if self.pooler is not None else None | |||||
if not return_dict: | |||||
return (sequence_output, | |||||
pooled_output) + encoder_outputs[1:] + (orignal_embeds, ) | |||||
return BaseModelOutputWithPoolingAndCrossAttentionsWithEmbedding( | |||||
last_hidden_state=sequence_output, | |||||
pooler_output=pooled_output, | |||||
past_key_values=encoder_outputs.past_key_values, | |||||
hidden_states=encoder_outputs.hidden_states, | |||||
attentions=encoder_outputs.attentions, | |||||
cross_attentions=encoder_outputs.cross_attentions, | |||||
embedding_output=orignal_embeds) | |||||
def extract_sequence_outputs(self, outputs): | |||||
return outputs['last_hidden_state'] | |||||
def extract_pooled_outputs(self, outputs): | |||||
return outputs['pooler_output'] | |||||
class SbertEmbeddings(nn.Module): | |||||
"""Construct the embeddings from word, position and token_type embeddings.""" | |||||
def __init__(self, config): | |||||
super().__init__() | |||||
self.word_embeddings = nn.Embedding( | |||||
config.vocab_size, | |||||
config.hidden_size, | |||||
padding_idx=config.pad_token_id) | |||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, | |||||
config.hidden_size) | |||||
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, | |||||
config.hidden_size) | |||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load | |||||
# any TensorFlow checkpoint file | |||||
self.LayerNorm = nn.LayerNorm( | |||||
config.hidden_size, eps=config.layer_norm_eps) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized | |||||
self.position_embedding_type = getattr(config, | |||||
'position_embedding_type', | |||||
'absolute') | |||||
self.register_buffer( | |||||
'position_ids', | |||||
torch.arange(config.max_position_embeddings).expand((1, -1))) | |||||
if version.parse(torch.__version__) > version.parse('1.6.0'): | |||||
self.register_buffer( | |||||
'token_type_ids', | |||||
torch.zeros( | |||||
self.position_ids.size(), | |||||
dtype=torch.long, | |||||
device=self.position_ids.device), | |||||
persistent=False, | |||||
) | |||||
def forward(self, | |||||
input_ids=None, | |||||
token_type_ids=None, | |||||
position_ids=None, | |||||
inputs_embeds=None, | |||||
past_key_values_length=0, | |||||
return_inputs_embeds=False): | |||||
if input_ids is not None: | |||||
input_shape = input_ids.size() | |||||
else: | |||||
input_shape = inputs_embeds.size()[:-1] | |||||
seq_length = input_shape[1] | |||||
if position_ids is None: | |||||
position_ids = self.position_ids[:, | |||||
past_key_values_length:seq_length | |||||
+ past_key_values_length] | |||||
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs | |||||
# when it's auto-generated; the registered buffer helps users when tracing the model without passing token_type_ids
# issue #5664 | |||||
if token_type_ids is None: | |||||
if hasattr(self, 'token_type_ids'): | |||||
buffered_token_type_ids = self.token_type_ids[:, :seq_length] | |||||
buffered_token_type_ids_expanded = buffered_token_type_ids.expand( | |||||
input_shape[0], seq_length) | |||||
token_type_ids = buffered_token_type_ids_expanded | |||||
else: | |||||
token_type_ids = torch.zeros( | |||||
input_shape, | |||||
dtype=torch.long, | |||||
device=self.position_ids.device) | |||||
if inputs_embeds is None: | |||||
inputs_embeds = self.word_embeddings(input_ids) | |||||
token_type_embeddings = self.token_type_embeddings(token_type_ids) | |||||
embeddings = inputs_embeds + token_type_embeddings | |||||
if self.position_embedding_type == 'absolute': | |||||
position_embeddings = self.position_embeddings(position_ids) | |||||
embeddings += position_embeddings | |||||
embeddings = self.LayerNorm(embeddings) | |||||
embeddings = self.dropout(embeddings) | |||||
if not return_inputs_embeds: | |||||
return embeddings | |||||
else: | |||||
return embeddings, inputs_embeds | |||||
class SbertSelfAttention(nn.Module): | |||||
def __init__(self, config): | |||||
super().__init__() | |||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr( | |||||
config, 'embedding_size'): | |||||
raise ValueError( | |||||
f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention ' | |||||
f'heads ({config.num_attention_heads})') | |||||
self.num_attention_heads = config.num_attention_heads | |||||
self.attention_head_size = int(config.hidden_size | |||||
/ config.num_attention_heads) | |||||
self.all_head_size = self.num_attention_heads * self.attention_head_size | |||||
self.query = nn.Linear(config.hidden_size, self.all_head_size) | |||||
self.key = nn.Linear(config.hidden_size, self.all_head_size) | |||||
self.value = nn.Linear(config.hidden_size, self.all_head_size) | |||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) | |||||
self.position_embedding_type = getattr(config, | |||||
'position_embedding_type', | |||||
'absolute') | |||||
if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': | |||||
self.max_position_embeddings = config.max_position_embeddings | |||||
self.distance_embedding = nn.Embedding( | |||||
2 * config.max_position_embeddings - 1, | |||||
self.attention_head_size) | |||||
self.is_decoder = config.is_decoder | |||||
def transpose_for_scores(self, x): | |||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, | |||||
self.attention_head_size) | |||||
x = x.view(*new_x_shape) | |||||
return x.permute(0, 2, 1, 3) | |||||
def forward( | |||||
self, | |||||
hidden_states, | |||||
attention_mask=None, | |||||
head_mask=None, | |||||
encoder_hidden_states=None, | |||||
encoder_attention_mask=None, | |||||
past_key_value=None, | |||||
output_attentions=False, | |||||
): | |||||
mixed_query_layer = self.query(hidden_states) | |||||
# If this is instantiated as a cross-attention module, the keys | |||||
# and values come from an encoder; the attention mask needs to be | |||||
# such that the encoder's padding tokens are not attended to. | |||||
is_cross_attention = encoder_hidden_states is not None | |||||
if is_cross_attention and past_key_value is not None: | |||||
# reuse k,v, cross_attentions | |||||
key_layer = past_key_value[0] | |||||
value_layer = past_key_value[1] | |||||
attention_mask = encoder_attention_mask | |||||
elif is_cross_attention: | |||||
key_layer = self.transpose_for_scores( | |||||
self.key(encoder_hidden_states)) | |||||
value_layer = self.transpose_for_scores( | |||||
self.value(encoder_hidden_states)) | |||||
attention_mask = encoder_attention_mask | |||||
elif past_key_value is not None: | |||||
key_layer = self.transpose_for_scores(self.key(hidden_states)) | |||||
value_layer = self.transpose_for_scores(self.value(hidden_states)) | |||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2) | |||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2) | |||||
else: | |||||
key_layer = self.transpose_for_scores(self.key(hidden_states)) | |||||
value_layer = self.transpose_for_scores(self.value(hidden_states)) | |||||
query_layer = self.transpose_for_scores(mixed_query_layer) | |||||
if self.is_decoder: | |||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. | |||||
# Further calls to cross_attention layer can then reuse all cross-attention | |||||
# key/value_states (first "if" case) | |||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of | |||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention | |||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) | |||||
# if encoder bi-directional self-attention `past_key_value` is always `None` | |||||
past_key_value = (key_layer, value_layer) | |||||
# Take the dot product between "query" and "key" to get the raw attention scores. | |||||
attention_scores = torch.matmul(query_layer, | |||||
key_layer.transpose(-1, -2)) | |||||
if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': | |||||
seq_length = hidden_states.size()[1] | |||||
position_ids_l = torch.arange( | |||||
seq_length, dtype=torch.long, | |||||
device=hidden_states.device).view(-1, 1) | |||||
position_ids_r = torch.arange( | |||||
seq_length, dtype=torch.long, | |||||
device=hidden_states.device).view(1, -1) | |||||
distance = position_ids_l - position_ids_r | |||||
positional_embedding = self.distance_embedding( | |||||
distance + self.max_position_embeddings - 1) | |||||
positional_embedding = positional_embedding.to( | |||||
dtype=query_layer.dtype) # fp16 compatibility | |||||
if self.position_embedding_type == 'relative_key': | |||||
relative_position_scores = torch.einsum( | |||||
'bhld,lrd->bhlr', query_layer, positional_embedding) | |||||
attention_scores = attention_scores + relative_position_scores | |||||
elif self.position_embedding_type == 'relative_key_query': | |||||
relative_position_scores_query = torch.einsum( | |||||
'bhld,lrd->bhlr', query_layer, positional_embedding) | |||||
relative_position_scores_key = torch.einsum( | |||||
'bhrd,lrd->bhlr', key_layer, positional_embedding) | |||||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key | |||||
attention_scores = attention_scores / math.sqrt( | |||||
self.attention_head_size) | |||||
if attention_mask is not None: | |||||
# Apply the attention mask is (precomputed for all layers in SbertModel forward() function) | |||||
attention_scores = attention_scores + attention_mask | |||||
# Normalize the attention scores to probabilities. | |||||
attention_probs = nn.Softmax(dim=-1)(attention_scores) | |||||
# This is actually dropping out entire tokens to attend to, which might | |||||
# seem a bit unusual, but is taken from the original Transformer paper. | |||||
attention_probs = self.dropout(attention_probs) | |||||
# Mask heads if we want to | |||||
if head_mask is not None: | |||||
attention_probs = attention_probs * head_mask | |||||
context_layer = torch.matmul(attention_probs, value_layer) | |||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |||||
new_context_layer_shape = context_layer.size()[:-2] + ( | |||||
self.all_head_size, ) | |||||
context_layer = context_layer.view(*new_context_layer_shape) | |||||
outputs = (context_layer, | |||||
attention_probs) if output_attentions else (context_layer, ) | |||||
if self.is_decoder: | |||||
outputs = outputs + (past_key_value, ) | |||||
return outputs | |||||
class SbertSelfOutput(nn.Module): | |||||
def __init__(self, config): | |||||
super().__init__() | |||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||||
self.LayerNorm = nn.LayerNorm( | |||||
config.hidden_size, eps=config.layer_norm_eps) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||||
def forward(self, hidden_states, input_tensor): | |||||
hidden_states = self.dense(hidden_states) | |||||
hidden_states = self.dropout(hidden_states) | |||||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||||
return hidden_states | |||||
class SbertAttention(nn.Module): | |||||
def __init__(self, config): | |||||
super().__init__() | |||||
self.self = SbertSelfAttention(config) | |||||
self.output = SbertSelfOutput(config) | |||||
self.pruned_heads = set() | |||||
def prune_heads(self, heads): | |||||
if len(heads) == 0: | |||||
return | |||||
heads, index = find_pruneable_heads_and_indices( | |||||
heads, self.self.num_attention_heads, | |||||
self.self.attention_head_size, self.pruned_heads) | |||||
# Prune linear layers | |||||
self.self.query = prune_linear_layer(self.self.query, index) | |||||
self.self.key = prune_linear_layer(self.self.key, index) | |||||
self.self.value = prune_linear_layer(self.self.value, index) | |||||
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) | |||||
# Update hyper params and store pruned heads | |||||
self.self.num_attention_heads = self.self.num_attention_heads - len( | |||||
heads) | |||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads | |||||
self.pruned_heads = self.pruned_heads.union(heads) | |||||
def forward( | |||||
self, | |||||
hidden_states, | |||||
attention_mask=None, | |||||
head_mask=None, | |||||
encoder_hidden_states=None, | |||||
encoder_attention_mask=None, | |||||
past_key_value=None, | |||||
output_attentions=False, | |||||
): | |||||
self_outputs = self.self( | |||||
hidden_states, | |||||
attention_mask, | |||||
head_mask, | |||||
encoder_hidden_states, | |||||
encoder_attention_mask, | |||||
past_key_value, | |||||
output_attentions, | |||||
) | |||||
attention_output = self.output(self_outputs[0], hidden_states) | |||||
outputs = (attention_output, | |||||
) + self_outputs[1:] # add attentions if we output them | |||||
return outputs | |||||
class SbertIntermediate(nn.Module): | |||||
def __init__(self, config): | |||||
super().__init__() | |||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size) | |||||
if isinstance(config.hidden_act, str): | |||||
self.intermediate_act_fn = ACT2FN[config.hidden_act] | |||||
else: | |||||
self.intermediate_act_fn = config.hidden_act | |||||
def forward(self, hidden_states): | |||||
hidden_states = self.dense(hidden_states) | |||||
hidden_states = self.intermediate_act_fn(hidden_states) | |||||
return hidden_states | |||||
class SbertOutput(nn.Module): | |||||
def __init__(self, config): | |||||
super().__init__() | |||||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size) | |||||
self.LayerNorm = nn.LayerNorm( | |||||
config.hidden_size, eps=config.layer_norm_eps) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||||
def forward(self, hidden_states, input_tensor): | |||||
hidden_states = self.dense(hidden_states) | |||||
hidden_states = self.dropout(hidden_states) | |||||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||||
return hidden_states | |||||
class SbertLayer(nn.Module): | |||||
def __init__(self, config): | |||||
super().__init__() | |||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward | |||||
self.seq_len_dim = 1 | |||||
self.attention = SbertAttention(config) | |||||
self.is_decoder = config.is_decoder | |||||
self.add_cross_attention = config.add_cross_attention | |||||
if self.add_cross_attention: | |||||
if not self.is_decoder: | |||||
raise ValueError( | |||||
f'{self} should be used as a decoder model if cross attention is added' | |||||
) | |||||
self.crossattention = SbertAttention(config) | |||||
self.intermediate = SbertIntermediate(config) | |||||
self.output = SbertOutput(config) | |||||
def forward( | |||||
self, | |||||
hidden_states, | |||||
attention_mask=None, | |||||
head_mask=None, | |||||
encoder_hidden_states=None, | |||||
encoder_attention_mask=None, | |||||
past_key_value=None, | |||||
output_attentions=False, | |||||
): | |||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 | |||||
self_attn_past_key_value = past_key_value[: | |||||
2] if past_key_value is not None else None | |||||
self_attention_outputs = self.attention( | |||||
hidden_states, | |||||
attention_mask, | |||||
head_mask, | |||||
output_attentions=output_attentions, | |||||
past_key_value=self_attn_past_key_value, | |||||
) | |||||
attention_output = self_attention_outputs[0] | |||||
# if decoder, the last output is tuple of self-attn cache | |||||
if self.is_decoder: | |||||
outputs = self_attention_outputs[1:-1] | |||||
present_key_value = self_attention_outputs[-1] | |||||
else: | |||||
outputs = self_attention_outputs[ | |||||
1:] # add self attentions if we output attention weights | |||||
cross_attn_present_key_value = None | |||||
if self.is_decoder and encoder_hidden_states is not None: | |||||
if not hasattr(self, 'crossattention'): | |||||
raise ValueError( | |||||
f'If `encoder_hidden_states` are passed, {self} has to be instantiated ' | |||||
f'with cross-attention layers by setting `config.add_cross_attention=True`' | |||||
) | |||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple | |||||
cross_attn_past_key_value = past_key_value[ | |||||
-2:] if past_key_value is not None else None | |||||
cross_attention_outputs = self.crossattention( | |||||
attention_output, | |||||
attention_mask, | |||||
head_mask, | |||||
encoder_hidden_states, | |||||
encoder_attention_mask, | |||||
cross_attn_past_key_value, | |||||
output_attentions, | |||||
) | |||||
attention_output = cross_attention_outputs[0] | |||||
outputs = outputs + cross_attention_outputs[ | |||||
1:-1] # add cross attentions if we output attention weights | |||||
# add cross-attn cache to positions 3,4 of present_key_value tuple | |||||
cross_attn_present_key_value = cross_attention_outputs[-1] | |||||
present_key_value = present_key_value + cross_attn_present_key_value | |||||
layer_output = apply_chunking_to_forward(self.feed_forward_chunk, | |||||
self.chunk_size_feed_forward, | |||||
self.seq_len_dim, | |||||
attention_output) | |||||
outputs = (layer_output, ) + outputs | |||||
# if decoder, return the attn key/values as the last output | |||||
if self.is_decoder: | |||||
outputs = outputs + (present_key_value, ) | |||||
return outputs | |||||
def feed_forward_chunk(self, attention_output): | |||||
intermediate_output = self.intermediate(attention_output) | |||||
layer_output = self.output(intermediate_output, attention_output) | |||||
return layer_output | |||||
class SbertEncoder(nn.Module): | |||||
def __init__(self, config): | |||||
super().__init__() | |||||
self.config = config | |||||
self.layer = nn.ModuleList( | |||||
[SbertLayer(config) for _ in range(config.num_hidden_layers)]) | |||||
self.gradient_checkpointing = False | |||||
def forward( | |||||
self, | |||||
hidden_states, | |||||
attention_mask=None, | |||||
head_mask=None, | |||||
encoder_hidden_states=None, | |||||
encoder_attention_mask=None, | |||||
past_key_values=None, | |||||
use_cache=None, | |||||
output_attentions=False, | |||||
output_hidden_states=False, | |||||
return_dict=True, | |||||
): | |||||
all_hidden_states = () if output_hidden_states else None | |||||
all_self_attentions = () if output_attentions else None | |||||
all_cross_attentions = ( | |||||
) if output_attentions and self.config.add_cross_attention else None | |||||
next_decoder_cache = () if use_cache else None | |||||
for i, layer_module in enumerate(self.layer): | |||||
if output_hidden_states: | |||||
all_hidden_states = all_hidden_states + (hidden_states, ) | |||||
layer_head_mask = head_mask[i] if head_mask is not None else None | |||||
past_key_value = past_key_values[ | |||||
i] if past_key_values is not None else None | |||||
if self.gradient_checkpointing and self.training: | |||||
if use_cache: | |||||
logger.warning( | |||||
'`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' | |||||
) | |||||
use_cache = False | |||||
def create_custom_forward(module): | |||||
def custom_forward(*inputs): | |||||
return module(*inputs, past_key_value, | |||||
output_attentions) | |||||
return custom_forward | |||||
layer_outputs = torch.utils.checkpoint.checkpoint( | |||||
create_custom_forward(layer_module), | |||||
hidden_states, | |||||
attention_mask, | |||||
layer_head_mask, | |||||
encoder_hidden_states, | |||||
encoder_attention_mask, | |||||
) | |||||
else: | |||||
layer_outputs = layer_module( | |||||
hidden_states, | |||||
attention_mask, | |||||
layer_head_mask, | |||||
encoder_hidden_states, | |||||
encoder_attention_mask, | |||||
past_key_value, | |||||
output_attentions, | |||||
) | |||||
hidden_states = layer_outputs[0] | |||||
if use_cache: | |||||
next_decoder_cache += (layer_outputs[-1], ) | |||||
if output_attentions: | |||||
all_self_attentions = all_self_attentions + ( | |||||
layer_outputs[1], ) | |||||
if self.config.add_cross_attention: | |||||
all_cross_attentions = all_cross_attentions + ( | |||||
layer_outputs[2], ) | |||||
if output_hidden_states: | |||||
all_hidden_states = all_hidden_states + (hidden_states, ) | |||||
if not return_dict: | |||||
return tuple(v for v in [ | |||||
hidden_states, | |||||
next_decoder_cache, | |||||
all_hidden_states, | |||||
all_self_attentions, | |||||
all_cross_attentions, | |||||
] if v is not None) | |||||
return BaseModelOutputWithPastAndCrossAttentions( | |||||
last_hidden_state=hidden_states, | |||||
past_key_values=next_decoder_cache, | |||||
hidden_states=all_hidden_states, | |||||
attentions=all_self_attentions, | |||||
cross_attentions=all_cross_attentions, | |||||
) | |||||
class SbertPooler(nn.Module): | |||||
def __init__(self, config): | |||||
super().__init__() | |||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||||
self.activation = nn.Tanh() | |||||
def forward(self, hidden_states): | |||||
# We "pool" the model by simply taking the hidden state corresponding | |||||
# to the first token. | |||||
first_token_tensor = hidden_states[:, 0] | |||||
pooled_output = self.dense(first_token_tensor) | |||||
pooled_output = self.activation(pooled_output) | |||||
return pooled_output | |||||
@dataclass | |||||
class SbertForPreTrainingOutput(ModelOutput): | |||||
""" | |||||
Output type of :class:`~structbert.utils.BertForPreTraining`. | |||||
Args: | |||||
loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): | |||||
Total loss as the sum of the masked language modeling loss and the next sequence prediction | |||||
(classification) loss. | |||||
prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): | |||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). | |||||
seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): | |||||
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation | |||||
before SoftMax). | |||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when | |||||
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): | |||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) | |||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`. | |||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs. | |||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when | |||||
``output_attentions=True`` is passed or when ``config.output_attentions=True``): | |||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, | |||||
sequence_length, sequence_length)`. | |||||
Attention weights after the attention softmax, used to compute the weighted average in the self-attention | |||||
heads. | |||||
""" | |||||
loss: Optional[torch.FloatTensor] = None | |||||
prediction_logits: torch.FloatTensor = None | |||||
seq_relationship_logits: torch.FloatTensor = None | |||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None | |||||
attentions: Optional[Tuple[torch.FloatTensor]] = None | |||||
@dataclass | |||||
class BaseModelOutputWithPoolingAndCrossAttentionsWithEmbedding( | |||||
BaseModelOutputWithPoolingAndCrossAttentions): | |||||
embedding_output: torch.FloatTensor = None | |||||
logits: Optional[Union[tuple, torch.FloatTensor]] = None | |||||
kwargs: dict = None |
@@ -0,0 +1,3 @@ | |||||
from .sequence_classification_head import SequenceClassificationHead | |||||
__all__ = ['SequenceClassificationHead'] |
@@ -0,0 +1,44 @@ | |||||
import importlib | |||||
from typing import Dict, List, Optional, Union | |||||
import torch | |||||
import torch.nn.functional as F | |||||
from torch import nn | |||||
from ....metainfo import Heads | |||||
from ....outputs import OutputKeys | |||||
from ....utils.constant import Tasks | |||||
from ...base import TorchHead | |||||
from ...builder import HEADS | |||||
@HEADS.register_module( | |||||
Tasks.text_classification, module_name=Heads.text_classification) | |||||
class SequenceClassificationHead(TorchHead): | |||||
def __init__(self, **kwargs): | |||||
super().__init__(**kwargs) | |||||
config = self.config | |||||
self.num_labels = config.num_labels | |||||
self.config = config | |||||
classifier_dropout = ( | |||||
config['classifier_dropout'] if config.get('classifier_dropout') | |||||
is not None else config['hidden_dropout_prob']) | |||||
self.dropout = nn.Dropout(classifier_dropout) | |||||
self.classifier = nn.Linear(config['hidden_size'], | |||||
config['num_labels']) | |||||
def forward(self, inputs=None): | |||||
if isinstance(inputs, dict): | |||||
assert inputs.get('pooled_output') is not None | |||||
pooled_output = inputs.get('pooled_output') | |||||
else: | |||||
pooled_output = inputs | |||||
pooled_output = self.dropout(pooled_output) | |||||
logits = self.classifier(pooled_output) | |||||
return {OutputKeys.LOGITS: logits} | |||||
def compute_loss(self, outputs: Dict[str, torch.Tensor], | |||||
labels) -> Dict[str, torch.Tensor]: | |||||
logits = outputs[OutputKeys.LOGITS] | |||||
return {OutputKeys.LOSS: F.cross_entropy(logits, labels)} |
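A minimal usage sketch of the new head (not part of the diff): it assumes the TorchHead base class is an nn.Module that stores its constructor kwargs as a ConfigDict on self.config, which is how the attribute and dict-style accesses in __init__ above are satisfied; shapes and label values below are made up.

import torch
from modelscope.models.nlp.heads import SequenceClassificationHead

# hypothetical config values; in a real model they come from the head section of the configuration
head = SequenceClassificationHead(
    hidden_size=768, hidden_dropout_prob=0.1, num_labels=2)

pooled = torch.randn(4, 768)                        # fake pooled backbone output
outputs = head.forward(dict(pooled_output=pooled))  # {'logits': tensor of shape (4, 2)}
loss = head.compute_loss(outputs, torch.tensor([0, 1, 1, 0]))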
@@ -2,8 +2,7 @@ from typing import Dict | |||||
from ...metainfo import Models | from ...metainfo import Models | ||||
from ...utils.constant import Tasks | from ...utils.constant import Tasks | ||||
from ..base import Tensor | |||||
from ..base_torch import TorchModel | |||||
from ..base import Tensor, TorchModel | |||||
from ..builder import MODELS | from ..builder import MODELS | ||||
__all__ = ['PalmForTextGeneration'] | __all__ = ['PalmForTextGeneration'] | ||||
@@ -42,6 +42,9 @@ class SbertTextClassfier(SbertPreTrainedModel): | |||||
return {'logits': logits, 'loss': loss} | return {'logits': logits, 'loss': loss} | ||||
return {'logits': logits} | return {'logits': logits} | ||||
def build(**kwargs): | |||||
return SbertTextClassfier.from_pretrained(model_dir, **model_args) | |||||
class SbertForSequenceClassificationBase(Model): | class SbertForSequenceClassificationBase(Model): | ||||
@@ -0,0 +1,85 @@ | |||||
import os | |||||
from typing import Any, Dict | |||||
import json | |||||
import numpy as np | |||||
from ...metainfo import TaskModels | |||||
from ...outputs import OutputKeys | |||||
from ...utils.constant import Tasks | |||||
from ..builder import MODELS | |||||
from .task_model import SingleBackboneTaskModelBase | |||||
__all__ = ['SequenceClassificationModel'] | |||||
@MODELS.register_module( | |||||
Tasks.sentiment_classification, module_name=TaskModels.text_classification) | |||||
@MODELS.register_module( | |||||
Tasks.text_classification, module_name=TaskModels.text_classification) | |||||
class SequenceClassificationModel(SingleBackboneTaskModelBase): | |||||
def __init__(self, model_dir: str, *args, **kwargs): | |||||
"""initialize the sequence classification model from the `model_dir` path. | |||||
Args: | |||||
model_dir (str): the model path. | |||||
""" | |||||
super().__init__(model_dir, *args, **kwargs) | |||||
if 'base_model_prefix' in kwargs: | |||||
self._base_model_prefix = kwargs['base_model_prefix'] | |||||
backbone_cfg = self.cfg.backbone | |||||
head_cfg = self.cfg.head | |||||
# get the num_labels from label_mapping.json | |||||
self.id2label = {} | |||||
self.label_path = os.path.join(model_dir, 'label_mapping.json') | |||||
if os.path.exists(self.label_path): | |||||
with open(self.label_path) as f: | |||||
self.label_mapping = json.load(f) | |||||
self.id2label = { | |||||
idx: name | |||||
for name, idx in self.label_mapping.items() | |||||
} | |||||
head_cfg['num_labels'] = len(self.label_mapping) | |||||
self.build_backbone(backbone_cfg) | |||||
self.build_head(head_cfg) | |||||
def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | |||||
outputs = super().forward(input) | |||||
sequence_output, pooled_output = self.extract_backbone_outputs(outputs) | |||||
outputs = self.head.forward(pooled_output) | |||||
if 'labels' in input: | |||||
loss = self.compute_loss(outputs, input['labels']) | |||||
outputs.update(loss) | |||||
return outputs | |||||
def extract_logits(self, outputs): | |||||
return outputs[OutputKeys.LOGITS].cpu().detach() | |||||
def extract_backbone_outputs(self, outputs): | |||||
sequence_output = None | |||||
pooled_output = None | |||||
if hasattr(self.backbone, 'extract_sequence_outputs'): | |||||
sequence_output = self.backbone.extract_sequence_outputs(outputs) | |||||
if hasattr(self.backbone, 'extract_pooled_outputs'): | |||||
pooled_output = self.backbone.extract_pooled_outputs(outputs) | |||||
return sequence_output, pooled_output | |||||
def compute_loss(self, outputs, labels): | |||||
loss = self.head.compute_loss(outputs, labels) | |||||
return loss | |||||
def postprocess(self, input, **kwargs): | |||||
logits = self.extract_logits(input) | |||||
probs = logits.softmax(-1).numpy() | |||||
pred = logits.argmax(-1).numpy() | |||||
logits = logits.numpy() | |||||
res = { | |||||
OutputKeys.PREDICTIONS: pred, | |||||
OutputKeys.PROBABILITIES: probs, | |||||
OutputKeys.LOGITS: logits | |||||
} | |||||
return res |
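A rough end-to-end sketch of how this task model is meant to be driven (not part of the diff). The model directory is a placeholder and is assumed to contain a configuration with backbone/head sections, label_mapping.json and pytorch_model.bin; the preprocessor is the same one used by the sentiment pipeline further below.

from modelscope.models import Model
from modelscope.preprocessors import SentimentClassificationPreprocessor

model = Model.from_pretrained('/path/to/model_dir')        # placeholder path
preprocessor = SentimentClassificationPreprocessor(model.model_dir)

inputs = preprocessor(('这家餐厅很好吃', '味道非常一般'))   # tokenized tensors
outputs = model.forward(inputs)                            # {'logits': ...}
print(model.postprocess(outputs))                          # predictions / probabilities / logits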
@@ -3,15 +3,14 @@ | |||||
import os | import os | ||||
from typing import Any, Dict | from typing import Any, Dict | ||||
from ....metainfo import Models | |||||
from ....preprocessors.space.fields.intent_field import IntentBPETextField | |||||
from ....trainers.nlp.space.trainer.intent_trainer import IntentTrainer | |||||
from ....utils.config import Config | |||||
from ....utils.constant import ModelFile, Tasks | |||||
from ...base import Model, Tensor | |||||
from ...builder import MODELS | |||||
from .model.generator import Generator | |||||
from .model.model_base import SpaceModelBase | |||||
from ...metainfo import Models | |||||
from ...preprocessors.space.fields.intent_field import IntentBPETextField | |||||
from ...trainers.nlp.space.trainer.intent_trainer import IntentTrainer | |||||
from ...utils.config import Config | |||||
from ...utils.constant import ModelFile, Tasks | |||||
from ..base import Model, Tensor | |||||
from ..builder import MODELS | |||||
from .backbones import SpaceGenerator, SpaceModelBase | |||||
__all__ = ['SpaceForDialogIntent'] | __all__ = ['SpaceForDialogIntent'] | ||||
@@ -37,7 +36,8 @@ class SpaceForDialogIntent(Model): | |||||
'text_field', | 'text_field', | ||||
IntentBPETextField(self.model_dir, config=self.config)) | IntentBPETextField(self.model_dir, config=self.config)) | ||||
self.generator = Generator.create(self.config, reader=self.text_field) | |||||
self.generator = SpaceGenerator.create( | |||||
self.config, reader=self.text_field) | |||||
self.model = SpaceModelBase.create( | self.model = SpaceModelBase.create( | ||||
model_dir=model_dir, | model_dir=model_dir, | ||||
config=self.config, | config=self.config, |
@@ -3,15 +3,14 @@ | |||||
import os | import os | ||||
from typing import Any, Dict, Optional | from typing import Any, Dict, Optional | ||||
from ....metainfo import Models | |||||
from ....preprocessors.space.fields.gen_field import MultiWOZBPETextField | |||||
from ....trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer | |||||
from ....utils.config import Config | |||||
from ....utils.constant import ModelFile, Tasks | |||||
from ...base import Model, Tensor | |||||
from ...builder import MODELS | |||||
from .model.generator import Generator | |||||
from .model.model_base import SpaceModelBase | |||||
from ...metainfo import Models | |||||
from ...preprocessors.space.fields.gen_field import MultiWOZBPETextField | |||||
from ...trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer | |||||
from ...utils.config import Config | |||||
from ...utils.constant import ModelFile, Tasks | |||||
from ..base import Model, Tensor | |||||
from ..builder import MODELS | |||||
from .backbones import SpaceGenerator, SpaceModelBase | |||||
__all__ = ['SpaceForDialogModeling'] | __all__ = ['SpaceForDialogModeling'] | ||||
@@ -35,7 +34,8 @@ class SpaceForDialogModeling(Model): | |||||
self.text_field = kwargs.pop( | self.text_field = kwargs.pop( | ||||
'text_field', | 'text_field', | ||||
MultiWOZBPETextField(self.model_dir, config=self.config)) | MultiWOZBPETextField(self.model_dir, config=self.config)) | ||||
self.generator = Generator.create(self.config, reader=self.text_field) | |||||
self.generator = SpaceGenerator.create( | |||||
self.config, reader=self.text_field) | |||||
self.model = SpaceModelBase.create( | self.model = SpaceModelBase.create( | ||||
model_dir=model_dir, | model_dir=model_dir, | ||||
config=self.config, | config=self.config, |
@@ -2,10 +2,10 @@ import os | |||||
from typing import Any, Dict | from typing import Any, Dict | ||||
from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
from ....metainfo import Models | |||||
from ....utils.nlp.space.utils_dst import batch_to_device | |||||
from ...base import Model, Tensor | |||||
from ...builder import MODELS | |||||
from ...metainfo import Models | |||||
from ...utils.nlp.space.utils_dst import batch_to_device | |||||
from ..base import Model, Tensor | |||||
from ..builder import MODELS | |||||
__all__ = ['SpaceForDialogStateTracking'] | __all__ = ['SpaceForDialogStateTracking'] | ||||
@@ -0,0 +1,489 @@ | |||||
import os.path | |||||
import re | |||||
from abc import ABC | |||||
from collections import OrderedDict | |||||
from typing import Any, Dict | |||||
import torch | |||||
from torch import nn | |||||
from ...utils.config import ConfigDict | |||||
from ...utils.constant import Fields, Tasks | |||||
from ...utils.logger import get_logger | |||||
from ...utils.utils import if_func_recieve_dict_inputs | |||||
from ..base import TorchModel | |||||
from ..builder import build_backbone, build_head | |||||
logger = get_logger(__name__) | |||||
__all__ = ['EncoderDecoderTaskModelBase', 'SingleBackboneTaskModelBase'] | |||||
def _repr(modules, depth=1): | |||||
# depth controls how many levels of sub-module names are included in the log | |||||
if depth == 0: | |||||
return modules._get_name() | |||||
# We treat the extra repr like the sub-module, one item per line | |||||
extra_lines = [] | |||||
extra_repr = modules.extra_repr() | |||||
# empty string will be split into list [''] | |||||
if extra_repr: | |||||
extra_lines = extra_repr.split('\n') | |||||
child_lines = [] | |||||
def _addindent(s_, numSpaces): | |||||
s = s_.split('\n') | |||||
# don't do anything for single-line stuff | |||||
if len(s) == 1: | |||||
return s_ | |||||
first = s.pop(0) | |||||
s = [(numSpaces * ' ') + line for line in s] | |||||
s = '\n'.join(s) | |||||
s = first + '\n' + s | |||||
return s | |||||
for key, module in modules._modules.items(): | |||||
mod_str = _repr(module, depth - 1) | |||||
mod_str = _addindent(mod_str, 2) | |||||
child_lines.append('(' + key + '): ' + mod_str) | |||||
lines = extra_lines + child_lines | |||||
main_str = modules._get_name() + '(' | |||||
if lines: | |||||
# simple one-liner info, which most builtin Modules will use | |||||
if len(extra_lines) == 1 and not child_lines: | |||||
main_str += extra_lines[0] | |||||
else: | |||||
main_str += '\n ' + '\n '.join(lines) + '\n' | |||||
main_str += ')' | |||||
return main_str | |||||
class BaseTaskModel(TorchModel, ABC): | |||||
""" Base task model interface for nlp | |||||
""" | |||||
# keys to ignore when load missing | |||||
_keys_to_ignore_on_load_missing = None | |||||
# keys to ignore when load unexpected | |||||
_keys_to_ignore_on_load_unexpected = None | |||||
# backbone prefix, default None | |||||
_backbone_prefix = None | |||||
def __init__(self, model_dir: str, *args, **kwargs): | |||||
super().__init__(model_dir, *args, **kwargs) | |||||
self.cfg = ConfigDict(kwargs) | |||||
def __repr__(self): | |||||
# only log backbone and head name | |||||
depth = 1 | |||||
return _repr(self, depth) | |||||
@classmethod | |||||
def _instantiate(cls, **kwargs): | |||||
model_dir = kwargs.get('model_dir') | |||||
model = cls(**kwargs) | |||||
model.load_checkpoint(model_local_dir=model_dir, **kwargs) | |||||
return model | |||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
pass | |||||
def load_checkpoint(self, | |||||
model_local_dir, | |||||
default_dtype=None, | |||||
load_state_fn=None, | |||||
**kwargs): | |||||
""" | |||||
Load model checkpoint file and feed the parameters into the model. | |||||
Args: | |||||
model_local_dir: The actual checkpoint dir on local disk. | |||||
default_dtype: Set the default float type by 'torch.set_default_dtype' | |||||
load_state_fn: An optional load_state_fn used to load state_dict into the model. | |||||
Returns: | |||||
""" | |||||
# TODO Sharded ckpt | |||||
ckpt_file = os.path.join(model_local_dir, 'pytorch_model.bin') | |||||
state_dict = torch.load(ckpt_file, map_location='cpu') | |||||
if default_dtype is not None: | |||||
torch.set_default_dtype(default_dtype) | |||||
missing_keys, unexpected_keys, mismatched_keys, error_msgs = self._load_checkpoint( | |||||
state_dict, | |||||
load_state_fn=load_state_fn, | |||||
ignore_mismatched_sizes=True, | |||||
_fast_init=True, | |||||
) | |||||
return { | |||||
'missing_keys': missing_keys, | |||||
'unexpected_keys': unexpected_keys, | |||||
'mismatched_keys': mismatched_keys, | |||||
'error_msgs': error_msgs, | |||||
} | |||||
def _load_checkpoint( | |||||
self, | |||||
state_dict, | |||||
load_state_fn, | |||||
ignore_mismatched_sizes, | |||||
_fast_init, | |||||
): | |||||
# Retrieve missing & unexpected_keys | |||||
model_state_dict = self.state_dict() | |||||
prefix = self._backbone_prefix | |||||
# add head prefix | |||||
new_state_dict = OrderedDict() | |||||
for name, module in state_dict.items(): | |||||
if not name.startswith(prefix) and not name.startswith('head'): | |||||
new_state_dict['.'.join(['head', name])] = module | |||||
else: | |||||
new_state_dict[name] = module | |||||
state_dict = new_state_dict | |||||
loaded_keys = [k for k in state_dict.keys()] | |||||
expected_keys = list(model_state_dict.keys()) | |||||
def _fix_key(key): | |||||
if 'beta' in key: | |||||
return key.replace('beta', 'bias') | |||||
if 'gamma' in key: | |||||
return key.replace('gamma', 'weight') | |||||
return key | |||||
original_loaded_keys = loaded_keys | |||||
loaded_keys = [_fix_key(key) for key in loaded_keys] | |||||
if len(prefix) > 0: | |||||
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) | |||||
expects_prefix_module = any( | |||||
s.startswith(prefix) for s in expected_keys) | |||||
else: | |||||
has_prefix_module = False | |||||
expects_prefix_module = False | |||||
# key re-naming operations are never done on the keys | |||||
# that are loaded, but always on the keys of the newly initialized model | |||||
remove_prefix_from_model = not has_prefix_module and expects_prefix_module | |||||
add_prefix_to_model = has_prefix_module and not expects_prefix_module | |||||
if remove_prefix_from_model: | |||||
expected_keys_not_prefixed = [ | |||||
s for s in expected_keys if not s.startswith(prefix) | |||||
] | |||||
expected_keys = [ | |||||
'.'.join(s.split('.')[1:]) if s.startswith(prefix) else s | |||||
for s in expected_keys | |||||
] | |||||
elif add_prefix_to_model: | |||||
expected_keys = ['.'.join([prefix, s]) for s in expected_keys] | |||||
missing_keys = list(set(expected_keys) - set(loaded_keys)) | |||||
unexpected_keys = list(set(loaded_keys) - set(expected_keys)) | |||||
if self._keys_to_ignore_on_load_missing is not None: | |||||
for pat in self._keys_to_ignore_on_load_missing: | |||||
missing_keys = [ | |||||
k for k in missing_keys if re.search(pat, k) is None | |||||
] | |||||
if self._keys_to_ignore_on_load_unexpected is not None: | |||||
for pat in self._keys_to_ignore_on_load_unexpected: | |||||
unexpected_keys = [ | |||||
k for k in unexpected_keys if re.search(pat, k) is None | |||||
] | |||||
if _fast_init: | |||||
# retrieve uninitialized modules and initialize them | |||||
uninitialized_modules = self.retrieve_modules_from_names( | |||||
missing_keys, | |||||
prefix=prefix, | |||||
add_prefix=add_prefix_to_model, | |||||
remove_prefix=remove_prefix_from_model) | |||||
for module in uninitialized_modules: | |||||
self._init_weights(module) | |||||
# Make sure we are able to load base models as well as derived models (with heads) | |||||
start_prefix = '' | |||||
model_to_load = self | |||||
if len(prefix) > 0 and not hasattr(self, prefix) and has_prefix_module: | |||||
start_prefix = prefix + '.' | |||||
if len(prefix) > 0 and hasattr(self, prefix) and not has_prefix_module: | |||||
model_to_load = getattr(self, prefix) | |||||
if any(key in expected_keys_not_prefixed for key in loaded_keys): | |||||
raise ValueError( | |||||
'The state dictionary of the model you are trying to load is corrupted. Are you sure it was ' | |||||
'properly saved?') | |||||
def _find_mismatched_keys( | |||||
state_dict, | |||||
model_state_dict, | |||||
loaded_keys, | |||||
add_prefix_to_model, | |||||
remove_prefix_from_model, | |||||
ignore_mismatched_sizes, | |||||
): | |||||
mismatched_keys = [] | |||||
if ignore_mismatched_sizes: | |||||
for checkpoint_key in loaded_keys: | |||||
model_key = checkpoint_key | |||||
if remove_prefix_from_model: | |||||
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. | |||||
model_key = f'{prefix}.{checkpoint_key}' | |||||
elif add_prefix_to_model: | |||||
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it. | |||||
model_key = '.'.join(checkpoint_key.split('.')[1:]) | |||||
if (model_key in model_state_dict): | |||||
model_shape = model_state_dict[model_key].shape | |||||
checkpoint_shape = state_dict[checkpoint_key].shape | |||||
if (checkpoint_shape != model_shape): | |||||
mismatched_keys.append( | |||||
(checkpoint_key, | |||||
state_dict[checkpoint_key].shape, | |||||
model_state_dict[model_key].shape)) | |||||
del state_dict[checkpoint_key] | |||||
return mismatched_keys | |||||
def _load_state_dict_into_model(model_to_load, state_dict, | |||||
start_prefix): | |||||
# Convert old format to new format if needed from a PyTorch state_dict | |||||
old_keys = [] | |||||
new_keys = [] | |||||
for key in state_dict.keys(): | |||||
new_key = None | |||||
if 'gamma' in key: | |||||
new_key = key.replace('gamma', 'weight') | |||||
if 'beta' in key: | |||||
new_key = key.replace('beta', 'bias') | |||||
if new_key: | |||||
old_keys.append(key) | |||||
new_keys.append(new_key) | |||||
for old_key, new_key in zip(old_keys, new_keys): | |||||
state_dict[new_key] = state_dict.pop(old_key) | |||||
# copy state_dict so _load_from_state_dict can modify it | |||||
metadata = getattr(state_dict, '_metadata', None) | |||||
state_dict = state_dict.copy() | |||||
if metadata is not None: | |||||
state_dict._metadata = metadata | |||||
error_msgs = [] | |||||
if load_state_fn is not None: | |||||
load_state_fn( | |||||
model_to_load, | |||||
state_dict, | |||||
prefix=start_prefix, | |||||
local_metadata=None, | |||||
error_msgs=error_msgs) | |||||
else: | |||||
def load(module: nn.Module, prefix=''): | |||||
local_metadata = {} if metadata is None else metadata.get( | |||||
prefix[:-1], {}) | |||||
args = (state_dict, prefix, local_metadata, True, [], [], | |||||
error_msgs) | |||||
module._load_from_state_dict(*args) | |||||
for name, child in module._modules.items(): | |||||
if child is not None: | |||||
load(child, prefix + name + '.') | |||||
load(model_to_load, prefix=start_prefix) | |||||
return error_msgs | |||||
# Whole checkpoint | |||||
mismatched_keys = _find_mismatched_keys( | |||||
state_dict, | |||||
model_state_dict, | |||||
original_loaded_keys, | |||||
add_prefix_to_model, | |||||
remove_prefix_from_model, | |||||
ignore_mismatched_sizes, | |||||
) | |||||
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, | |||||
start_prefix) | |||||
if len(error_msgs) > 0: | |||||
error_msg = '\n\t'.join(error_msgs) | |||||
raise RuntimeError( | |||||
f'Error(s) in loading state_dict for {self.__class__.__name__}:\n\t{error_msg}' | |||||
) | |||||
if len(unexpected_keys) > 0: | |||||
logger.warning( | |||||
f'Some weights of the model checkpoint were not used when' | |||||
f' initializing {self.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are' | |||||
f' initializing {self.__class__.__name__} from the checkpoint of a model trained on another task or' | |||||
' with another architecture (e.g. initializing a BertForSequenceClassification model from a' | |||||
' BertForPreTraining model).\n- This IS NOT expected if you are initializing' | |||||
f' {self.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical' | |||||
' (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).' | |||||
) | |||||
else: | |||||
logger.info( | |||||
f'All model checkpoint weights were used when initializing {self.__class__.__name__}.\n' | |||||
) | |||||
if len(missing_keys) > 0: | |||||
logger.warning( | |||||
f'Some weights of {self.__class__.__name__} were not initialized from the model checkpoint' | |||||
f' and are newly initialized: {missing_keys}\nYou should probably' | |||||
' TRAIN this model on a down-stream task to be able to use it for predictions and inference.' | |||||
) | |||||
elif len(mismatched_keys) == 0: | |||||
logger.info( | |||||
f'All the weights of {self.__class__.__name__} were initialized from the model checkpoint.\n' | |||||
f'If your task is similar to the task the model of the checkpoint' | |||||
f' was trained on, you can already use {self.__class__.__name__} for predictions without further' | |||||
' training.') | |||||
if len(mismatched_keys) > 0: | |||||
mismatched_warning = '\n'.join([ | |||||
f'- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated' | |||||
for key, shape1, shape2 in mismatched_keys | |||||
]) | |||||
logger.warning( | |||||
f'Some weights of {self.__class__.__name__} were not initialized from the model checkpoint' | |||||
f' and are newly initialized because the shapes did not' | |||||
f' match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able' | |||||
' to use it for predictions and inference.') | |||||
return missing_keys, unexpected_keys, mismatched_keys, error_msgs | |||||
def retrieve_modules_from_names(self, | |||||
names, | |||||
prefix=None, | |||||
add_prefix=False, | |||||
remove_prefix=False): | |||||
module_keys = set(['.'.join(key.split('.')[:-1]) for key in names]) | |||||
# torch.nn.ParameterList is a special case where two parameter keywords | |||||
# are appended to the module name, *e.g.* bert.special_embeddings.0 | |||||
module_keys = module_keys.union( | |||||
set([ | |||||
'.'.join(key.split('.')[:-2]) for key in names | |||||
if key[-1].isdigit() | |||||
])) | |||||
retrieved_modules = [] | |||||
# retrieve all modules that have at least one missing weight name | |||||
for name, module in self.named_modules(): | |||||
if remove_prefix: | |||||
name = '.'.join( | |||||
name.split('.')[1:]) if name.startswith(prefix) else name | |||||
elif add_prefix: | |||||
name = '.'.join([prefix, name]) if len(name) > 0 else prefix | |||||
if name in module_keys: | |||||
retrieved_modules.append(module) | |||||
return retrieved_modules | |||||
class SingleBackboneTaskModelBase(BaseTaskModel): | |||||
""" | |||||
This is the base class for all single-backbone NLP task models. | |||||
""" | |||||
# The backbone prefix defaults to "bert" | |||||
_backbone_prefix = 'bert' | |||||
# The head prefix defaults to "head" | |||||
_head_prefix = 'head' | |||||
def __init__(self, model_dir: str, *args, **kwargs): | |||||
super().__init__(model_dir, *args, **kwargs) | |||||
def build_backbone(self, cfg): | |||||
if 'prefix' in cfg: | |||||
self._backbone_prefix = cfg['prefix'] | |||||
backbone = build_backbone(cfg, field=Fields.nlp) | |||||
setattr(self, cfg['prefix'], backbone) | |||||
def build_head(self, cfg): | |||||
if 'prefix' in cfg: | |||||
self._head_prefix = cfg['prefix'] | |||||
head = build_head(cfg) | |||||
setattr(self, self._head_prefix, head) | |||||
return head | |||||
@property | |||||
def backbone(self): | |||||
if 'backbone' != self._backbone_prefix: | |||||
return getattr(self, self._backbone_prefix) | |||||
return super().__getattr__('backbone') | |||||
@property | |||||
def head(self): | |||||
if 'head' != self._head_prefix: | |||||
return getattr(self, self._head_prefix) | |||||
return super().__getattr__('head') | |||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
"""default forward method is the backbone-only forward""" | |||||
if if_func_recieve_dict_inputs(self.backbone.forward, input): | |||||
outputs = self.backbone.forward(input) | |||||
else: | |||||
outputs = self.backbone.forward(**input) | |||||
return outputs | |||||
class EncoderDecoderTaskModelBase(BaseTaskModel): | |||||
""" | |||||
This is the base class for encoder-decoder NLP task models. | |||||
""" | |||||
# The encoder backbone prefix, defaults to "encoder" | |||||
_encoder_prefix = 'encoder' | |||||
# The decoder backbone prefix, defaults to "decoder" | |||||
_decoder_prefix = 'decoder' | |||||
# The key in cfg specifying the encoder type | |||||
_encoder_key_in_cfg = 'encoder_type' | |||||
# The key in cfg specifying the decoder type | |||||
_decoder_key_in_cfg = 'decoder_type' | |||||
def __init__(self, model_dir: str, *args, **kwargs): | |||||
super().__init__(model_dir, *args, **kwargs) | |||||
def build_encoder(self): | |||||
encoder = build_backbone( | |||||
self.cfg, | |||||
type_name=self._encoder_key_in_cfg, | |||||
task_name=Tasks.backbone) | |||||
setattr(self, self._encoder_prefix, encoder) | |||||
return encoder | |||||
def build_decoder(self): | |||||
decoder = build_backbone( | |||||
self.cfg, | |||||
type_name=self._decoder_key_in_cfg, | |||||
task_name=Tasks.backbone) | |||||
setattr(self, self._decoder_prefix, decoder) | |||||
return decoder | |||||
@property | |||||
def encoder_(self): | |||||
return getattr(self, self._encoder_prefix) | |||||
@property | |||||
def decoder_(self): | |||||
return getattr(self, self._decoder_prefix) | |||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
if if_func_recieve_dict_inputs(self.encoder_.forward, input): | |||||
encoder_outputs = self.encoder_.forward(input) | |||||
else: | |||||
encoder_outputs = self.encoder_.forward(**input) | |||||
decoder_inputs = self.project_decoder_inputs_and_mediate( | |||||
input, encoder_outputs) | |||||
if if_func_recieve_dict_inputs(self.decoder_.forward, input): | |||||
outputs = self.decoder_.forward(decoder_inputs) | |||||
else: | |||||
outputs = self.decoder_.forward(**decoder_inputs) | |||||
return outputs | |||||
def project_decoder_inputs_and_mediate(self, input, encoder_outputs): | |||||
return {**input, **encoder_outputs} |
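A toy illustration (not part of the diff) of the key handling inside _load_checkpoint above: checkpoint entries that start with neither the backbone prefix nor 'head' are treated as head weights, and legacy 'gamma'/'beta' parameter names are rewritten to 'weight'/'bias'.

from collections import OrderedDict

prefix = 'bert'  # the default _backbone_prefix of SingleBackboneTaskModelBase
state_dict = OrderedDict([
    ('bert.encoder.layer.0.attention.self.query.weight', 0),  # kept as-is
    ('classifier.weight', 1),                # no prefix -> 'head.classifier.weight'
    ('bert.embeddings.LayerNorm.gamma', 2),  # 'gamma' -> 'weight' via _fix_key
])

remapped = OrderedDict()
for name, value in state_dict.items():
    if not name.startswith(prefix) and not name.startswith('head'):
        remapped['.'.join(['head', name])] = value
    else:
        remapped[name] = value

fixed_keys = [k.replace('gamma', 'weight').replace('beta', 'bias') for k in remapped]
print(fixed_keys)
# ['bert.encoder.layer.0.attention.self.query.weight',
#  'head.classifier.weight',
#  'bert.embeddings.LayerNorm.weight']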
@@ -4,6 +4,7 @@ from modelscope.utils.constant import Tasks | |||||
class OutputKeys(object): | class OutputKeys(object): | ||||
LOSS = 'loss' | |||||
LOGITS = 'logits' | LOGITS = 'logits' | ||||
SCORES = 'scores' | SCORES = 'scores' | ||||
LABEL = 'label' | LABEL = 'label' | ||||
@@ -22,6 +23,8 @@ class OutputKeys(object): | |||||
TRANSLATION = 'translation' | TRANSLATION = 'translation' | ||||
RESPONSE = 'response' | RESPONSE = 'response' | ||||
PREDICTION = 'prediction' | PREDICTION = 'prediction' | ||||
PREDICTIONS = 'predictions' | |||||
PROBABILITIES = 'probabilities' | |||||
DIALOG_STATES = 'dialog_states' | DIALOG_STATES = 'dialog_states' | ||||
VIDEO_EMBEDDING = 'video_embedding' | VIDEO_EMBEDDING = 'video_embedding' | ||||
@@ -30,7 +30,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
Tasks.nli: (Pipelines.nli, 'damo/nlp_structbert_nli_chinese-base'), | Tasks.nli: (Pipelines.nli, 'damo/nlp_structbert_nli_chinese-base'), | ||||
Tasks.sentiment_classification: | Tasks.sentiment_classification: | ||||
(Pipelines.sentiment_classification, | (Pipelines.sentiment_classification, | ||||
'damo/nlp_structbert_sentiment-classification_chinese-base'), | |||||
'damo/nlp_structbert_sentiment-classification_chinese-base' | |||||
), # TODO: revise back after passing the pr | |||||
Tasks.image_matting: (Pipelines.image_matting, | Tasks.image_matting: (Pipelines.image_matting, | ||||
'damo/cv_unet_image-matting'), | 'damo/cv_unet_image-matting'), | ||||
Tasks.text_classification: (Pipelines.sentiment_analysis, | Tasks.text_classification: (Pipelines.sentiment_analysis, | ||||
@@ -2,10 +2,10 @@ | |||||
from typing import Any, Dict, Union | from typing import Any, Dict, Union | ||||
from modelscope.outputs import OutputKeys | |||||
from ...metainfo import Pipelines | from ...metainfo import Pipelines | ||||
from ...models import Model | from ...models import Model | ||||
from ...models.nlp import SpaceForDialogIntent | from ...models.nlp import SpaceForDialogIntent | ||||
from ...outputs import OutputKeys | |||||
from ...preprocessors import DialogIntentPredictionPreprocessor | from ...preprocessors import DialogIntentPredictionPreprocessor | ||||
from ...utils.constant import Tasks | from ...utils.constant import Tasks | ||||
from ..base import Pipeline | from ..base import Pipeline | ||||
@@ -2,10 +2,10 @@ | |||||
from typing import Dict, Union | from typing import Dict, Union | ||||
from modelscope.outputs import OutputKeys | |||||
from ...metainfo import Pipelines | from ...metainfo import Pipelines | ||||
from ...models import Model | from ...models import Model | ||||
from ...models.nlp import SpaceForDialogModeling | from ...models.nlp import SpaceForDialogModeling | ||||
from ...outputs import OutputKeys | |||||
from ...preprocessors import DialogModelingPreprocessor | from ...preprocessors import DialogModelingPreprocessor | ||||
from ...utils.constant import Tasks | from ...utils.constant import Tasks | ||||
from ..base import Pipeline, Tensor | from ..base import Pipeline, Tensor | ||||
@@ -1,8 +1,8 @@ | |||||
from typing import Any, Dict, Union | from typing import Any, Dict, Union | ||||
from modelscope.outputs import OutputKeys | |||||
from ...metainfo import Pipelines | from ...metainfo import Pipelines | ||||
from ...models import Model, SpaceForDialogStateTracking | from ...models import Model, SpaceForDialogStateTracking | ||||
from ...outputs import OutputKeys | |||||
from ...preprocessors import DialogStateTrackingPreprocessor | from ...preprocessors import DialogStateTrackingPreprocessor | ||||
from ...utils.constant import Tasks | from ...utils.constant import Tasks | ||||
from ..base import Pipeline | from ..base import Pipeline | ||||
@@ -6,7 +6,7 @@ import torch | |||||
from modelscope.outputs import OutputKeys | from modelscope.outputs import OutputKeys | ||||
from ...metainfo import Pipelines | from ...metainfo import Pipelines | ||||
from ...models import Model | from ...models import Model | ||||
from ...models.nlp import SbertForSentimentClassification | |||||
from ...models.nlp import SequenceClassificationModel | |||||
from ...preprocessors import SentimentClassificationPreprocessor | from ...preprocessors import SentimentClassificationPreprocessor | ||||
from ...utils.constant import Tasks | from ...utils.constant import Tasks | ||||
from ..base import Pipeline | from ..base import Pipeline | ||||
@@ -21,7 +21,7 @@ __all__ = ['SentimentClassificationPipeline'] | |||||
class SentimentClassificationPipeline(Pipeline): | class SentimentClassificationPipeline(Pipeline): | ||||
def __init__(self, | def __init__(self, | ||||
model: Union[SbertForSentimentClassification, str], | |||||
model: Union[SequenceClassificationModel, str], | |||||
preprocessor: SentimentClassificationPreprocessor = None, | preprocessor: SentimentClassificationPreprocessor = None, | ||||
first_sequence='first_sequence', | first_sequence='first_sequence', | ||||
second_sequence='second_sequence', | second_sequence='second_sequence', | ||||
@@ -29,14 +29,14 @@ class SentimentClassificationPipeline(Pipeline): | |||||
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | ||||
Args: | Args: | ||||
model (SbertForSentimentClassification): a model instance | |||||
model (SequenceClassificationModel): a model instance | |||||
preprocessor (SentimentClassificationPreprocessor): a preprocessor instance | preprocessor (SentimentClassificationPreprocessor): a preprocessor instance | ||||
""" | """ | ||||
assert isinstance(model, str) or isinstance(model, SbertForSentimentClassification), \ | |||||
'model must be a single str or SbertForSentimentClassification' | |||||
assert isinstance(model, str) or isinstance(model, SequenceClassificationModel), \ | |||||
'model must be a single str or SequenceClassificationModel' | |||||
model = model if isinstance( | model = model if isinstance( | ||||
model, | model, | ||||
SbertForSentimentClassification) else Model.from_pretrained(model) | |||||
SequenceClassificationModel) else Model.from_pretrained(model) | |||||
if preprocessor is None: | if preprocessor is None: | ||||
preprocessor = SentimentClassificationPreprocessor( | preprocessor = SentimentClassificationPreprocessor( | ||||
model.model_dir, | model.model_dir, | ||||
@@ -3,12 +3,23 @@ | |||||
from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||
from typing import Any, Dict | from typing import Any, Dict | ||||
from modelscope.utils.constant import ModeKeys | |||||
class Preprocessor(ABC): | class Preprocessor(ABC): | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
self._mode = ModeKeys.INFERENCE | |||||
pass | pass | ||||
@abstractmethod | @abstractmethod | ||||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | ||||
pass | pass | ||||
@property | |||||
def mode(self): | |||||
return self._mode | |||||
@mode.setter | |||||
def mode(self, value): | |||||
self._mode = value |
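A minimal sketch of the new mode switch (not part of the diff): a hypothetical subclass inherits the INFERENCE default assigned in __init__, and the trainer flips it to TRAIN before building datasets, as the trainer change later in this diff does.

from typing import Any, Dict

from modelscope.preprocessors.base import Preprocessor
from modelscope.utils.constant import ModeKeys


class EchoPreprocessor(Preprocessor):
    """Hypothetical subclass used only to demonstrate the mode property."""

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        return data


p = EchoPreprocessor()
assert p.mode == ModeKeys.INFERENCE  # default assigned in Preprocessor.__init__
p.mode = ModeKeys.TRAIN              # what EpochBasedTrainer now sets before training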
@@ -7,7 +7,7 @@ from transformers import AutoTokenizer | |||||
from ..metainfo import Preprocessors | from ..metainfo import Preprocessors | ||||
from ..models import Model | from ..models import Model | ||||
from ..utils.constant import Fields, InputFields | |||||
from ..utils.constant import Fields, InputFields, ModeKeys | |||||
from ..utils.hub import parse_label_mapping | from ..utils.hub import parse_label_mapping | ||||
from ..utils.type_assert import type_assert | from ..utils.type_assert import type_assert | ||||
from .base import Preprocessor | from .base import Preprocessor | ||||
@@ -52,6 +52,7 @@ class NLPPreprocessorBase(Preprocessor): | |||||
self.second_sequence = kwargs.pop('second_sequence', 'second_sequence') | self.second_sequence = kwargs.pop('second_sequence', 'second_sequence') | ||||
self.tokenize_kwargs = kwargs | self.tokenize_kwargs = kwargs | ||||
self.tokenizer = self.build_tokenizer(model_dir) | self.tokenizer = self.build_tokenizer(model_dir) | ||||
self.label2id = parse_label_mapping(self.model_dir) | |||||
def build_tokenizer(self, model_dir): | def build_tokenizer(self, model_dir): | ||||
from sofa import SbertTokenizer | from sofa import SbertTokenizer | ||||
@@ -83,7 +84,12 @@ class NLPPreprocessorBase(Preprocessor): | |||||
text_a = data.get(self.first_sequence) | text_a = data.get(self.first_sequence) | ||||
text_b = data.get(self.second_sequence, None) | text_b = data.get(self.second_sequence, None) | ||||
return self.tokenizer(text_a, text_b, **self.tokenize_kwargs) | |||||
rst = self.tokenizer(text_a, text_b, **self.tokenize_kwargs) | |||||
if self._mode == ModeKeys.TRAIN: | |||||
rst = {k: v.squeeze() for k, v in rst.items()} | |||||
if self.label2id is not None and 'label' in data: | |||||
rst['label'] = self.label2id[str(data['label'])] | |||||
return rst | |||||
@PREPROCESSORS.register_module( | @PREPROCESSORS.register_module( | ||||
@@ -200,16 +206,6 @@ class SentenceSimilarityFinetunePreprocessor(SentenceSimilarityPreprocessor): | |||||
def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
kwargs['padding'] = 'max_length' | kwargs['padding'] = 'max_length' | ||||
super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
self.label2id = parse_label_mapping(self.model_dir) | |||||
@type_assert(object, (str, tuple, Dict)) | |||||
def __call__(self, data: Union[str, tuple, Dict]) -> Dict[str, Any]: | |||||
rst = super().__call__(data) | |||||
rst = {k: v.squeeze() for k, v in rst.items()} | |||||
if self.label2id is not None and 'label' in data: | |||||
rst['labels'] = [] | |||||
rst['labels'].append(self.label2id[str(data['label'])]) | |||||
return rst | |||||
@PREPROCESSORS.register_module( | @PREPROCESSORS.register_module( | ||||
@@ -2,6 +2,7 @@ | |||||
import os.path | import os.path | ||||
import random | import random | ||||
import time | import time | ||||
from collections.abc import Mapping | |||||
from distutils.version import LooseVersion | from distutils.version import LooseVersion | ||||
from functools import partial | from functools import partial | ||||
from typing import Callable, List, Optional, Tuple, Union | from typing import Callable, List, Optional, Tuple, Union | ||||
@@ -16,8 +17,7 @@ from torch.utils.data.distributed import DistributedSampler | |||||
from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
from modelscope.metrics import build_metric, task_default_metrics | from modelscope.metrics import build_metric, task_default_metrics | ||||
from modelscope.models.base import Model | |||||
from modelscope.models.base_torch import TorchModel | |||||
from modelscope.models.base import Model, TorchModel | |||||
from modelscope.msdatasets.ms_dataset import MsDataset | from modelscope.msdatasets.ms_dataset import MsDataset | ||||
from modelscope.preprocessors import build_preprocessor | from modelscope.preprocessors import build_preprocessor | ||||
from modelscope.preprocessors.base import Preprocessor | from modelscope.preprocessors.base import Preprocessor | ||||
@@ -26,12 +26,13 @@ from modelscope.trainers.hooks.priority import Priority, get_priority | |||||
from modelscope.trainers.lrscheduler.builder import build_lr_scheduler | from modelscope.trainers.lrscheduler.builder import build_lr_scheduler | ||||
from modelscope.trainers.optimizer.builder import build_optimizer | from modelscope.trainers.optimizer.builder import build_optimizer | ||||
from modelscope.utils.config import Config, ConfigDict | from modelscope.utils.config import Config, ConfigDict | ||||
from modelscope.utils.constant import (Hubs, ModeKeys, ModelFile, Tasks, | |||||
TrainerStages) | |||||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Hubs, ModeKeys, | |||||
ModelFile, Tasks, TrainerStages) | |||||
from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
from modelscope.utils.registry import build_from_cfg | from modelscope.utils.registry import build_from_cfg | ||||
from modelscope.utils.tensor_utils import torch_default_data_collator | from modelscope.utils.tensor_utils import torch_default_data_collator | ||||
from modelscope.utils.torch_utils import get_dist_info | from modelscope.utils.torch_utils import get_dist_info | ||||
from modelscope.utils.utils import if_func_recieve_dict_inputs | |||||
from .base import BaseTrainer | from .base import BaseTrainer | ||||
from .builder import TRAINERS | from .builder import TRAINERS | ||||
from .default_config import DEFAULT_CONFIG | from .default_config import DEFAULT_CONFIG | ||||
@@ -79,13 +80,15 @@ class EpochBasedTrainer(BaseTrainer): | |||||
optimizers: Tuple[torch.optim.Optimizer, | optimizers: Tuple[torch.optim.Optimizer, | ||||
torch.optim.lr_scheduler._LRScheduler] = (None, | torch.optim.lr_scheduler._LRScheduler] = (None, | ||||
None), | None), | ||||
model_revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||||
**kwargs): | **kwargs): | ||||
if isinstance(model, str): | if isinstance(model, str): | ||||
if os.path.exists(model): | if os.path.exists(model): | ||||
self.model_dir = model if os.path.isdir( | self.model_dir = model if os.path.isdir( | ||||
model) else os.path.dirname(model) | model) else os.path.dirname(model) | ||||
else: | else: | ||||
self.model_dir = snapshot_download(model) | |||||
self.model_dir = snapshot_download( | |||||
model, revision=model_revision) | |||||
cfg_file = os.path.join(self.model_dir, ModelFile.CONFIGURATION) | cfg_file = os.path.join(self.model_dir, ModelFile.CONFIGURATION) | ||||
self.model = self.build_model() | self.model = self.build_model() | ||||
else: | else: | ||||
@@ -112,6 +115,8 @@ class EpochBasedTrainer(BaseTrainer): | |||||
self.preprocessor = preprocessor | self.preprocessor = preprocessor | ||||
elif hasattr(self.cfg, 'preprocessor'): | elif hasattr(self.cfg, 'preprocessor'): | ||||
self.preprocessor = self.build_preprocessor() | self.preprocessor = self.build_preprocessor() | ||||
if self.preprocessor is not None: | |||||
self.preprocessor.mode = ModeKeys.TRAIN | |||||
# TODO @wenmeng.zwm add data collator option | # TODO @wenmeng.zwm add data collator option | ||||
# TODO how to fill device option? | # TODO how to fill device option? | ||||
self.device = int( | self.device = int( | ||||
@@ -264,7 +269,8 @@ class EpochBasedTrainer(BaseTrainer): | |||||
model = Model.from_pretrained(self.model_dir) | model = Model.from_pretrained(self.model_dir) | ||||
if not isinstance(model, nn.Module) and hasattr(model, 'model'): | if not isinstance(model, nn.Module) and hasattr(model, 'model'): | ||||
return model.model | return model.model | ||||
return model | |||||
elif isinstance(model, nn.Module): | |||||
return model | |||||
def collate_fn(self, data): | def collate_fn(self, data): | ||||
"""Prepare the input just before the forward function. | """Prepare the input just before the forward function. | ||||
@@ -307,7 +313,8 @@ class EpochBasedTrainer(BaseTrainer): | |||||
model.train() | model.train() | ||||
self._mode = ModeKeys.TRAIN | self._mode = ModeKeys.TRAIN | ||||
inputs = self.collate_fn(inputs) | inputs = self.collate_fn(inputs) | ||||
if not isinstance(model, Model) and isinstance(inputs, dict): | |||||
if isinstance(inputs, Mapping) and not if_func_recieve_dict_inputs( | |||||
model.forward, inputs): | |||||
train_outputs = model.forward(**inputs) | train_outputs = model.forward(**inputs) | ||||
else: | else: | ||||
train_outputs = model.forward(inputs) | train_outputs = model.forward(inputs) | ||||
@@ -5,6 +5,7 @@ import pickle | |||||
import shutil | import shutil | ||||
import tempfile | import tempfile | ||||
import time | import time | ||||
from collections.abc import Mapping | |||||
import torch | import torch | ||||
from torch import distributed as dist | from torch import distributed as dist | ||||
@@ -12,6 +13,7 @@ from tqdm import tqdm | |||||
from modelscope.models.base import Model | from modelscope.models.base import Model | ||||
from modelscope.utils.torch_utils import get_dist_info | from modelscope.utils.torch_utils import get_dist_info | ||||
from modelscope.utils.utils import if_func_recieve_dict_inputs | |||||
def single_gpu_test(model, | def single_gpu_test(model, | ||||
@@ -36,7 +38,10 @@ def single_gpu_test(model, | |||||
if data_collate_fn is not None: | if data_collate_fn is not None: | ||||
data = data_collate_fn(data) | data = data_collate_fn(data) | ||||
with torch.no_grad(): | with torch.no_grad(): | ||||
if not isinstance(model, Model): | |||||
if isinstance(data, | |||||
Mapping) and not if_func_recieve_dict_inputs( | |||||
model.forward, data): | |||||
result = model(**data) | result = model(**data) | ||||
else: | else: | ||||
result = model(data) | result = model(data) | ||||
@@ -87,7 +92,9 @@ def multi_gpu_test(model, | |||||
if data_collate_fn is not None: | if data_collate_fn is not None: | ||||
data = data_collate_fn(data) | data = data_collate_fn(data) | ||||
with torch.no_grad(): | with torch.no_grad(): | ||||
if not isinstance(model, Model): | |||||
if isinstance(data, | |||||
Mapping) and not if_func_recieve_dict_inputs( | |||||
model.forward, data): | |||||
result = model(**data) | result = model(**data) | ||||
else: | else: | ||||
result = model(data) | result = model(data) | ||||
@@ -57,6 +57,7 @@ class NLPTasks(object): | |||||
summarization = 'summarization' | summarization = 'summarization' | ||||
question_answering = 'question-answering' | question_answering = 'question-answering' | ||||
zero_shot_classification = 'zero-shot-classification' | zero_shot_classification = 'zero-shot-classification' | ||||
backbone = 'backbone' | |||||
class AudioTasks(object): | class AudioTasks(object): | ||||
@@ -173,6 +174,7 @@ DEFAULT_DATASET_REVISION = 'master' | |||||
class ModeKeys: | class ModeKeys: | ||||
TRAIN = 'train' | TRAIN = 'train' | ||||
EVAL = 'eval' | EVAL = 'eval' | ||||
INFERENCE = 'inference' | |||||
class LogKeys: | class LogKeys: | ||||
@@ -6,6 +6,7 @@ from typing import List, Tuple, Union | |||||
from modelscope.utils.import_utils import requires | from modelscope.utils.import_utils import requires | ||||
from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
TYPE_NAME = 'type' | |||||
default_group = 'default' | default_group = 'default' | ||||
logger = get_logger() | logger = get_logger() | ||||
@@ -159,15 +160,16 @@ def build_from_cfg(cfg, | |||||
group_key (str, optional): The name of registry group from which | group_key (str, optional): The name of registry group from which | ||||
module should be searched. | module should be searched. | ||||
default_args (dict, optional): Default initialization arguments. | default_args (dict, optional): Default initialization arguments. | ||||
type_name (str, optional): The name of the type in the config. | |||||
Returns: | Returns: | ||||
object: The constructed object. | object: The constructed object. | ||||
""" | """ | ||||
if not isinstance(cfg, dict): | if not isinstance(cfg, dict): | ||||
raise TypeError(f'cfg must be a dict, but got {type(cfg)}') | raise TypeError(f'cfg must be a dict, but got {type(cfg)}') | ||||
if 'type' not in cfg: | |||||
if default_args is None or 'type' not in default_args: | |||||
if TYPE_NAME not in cfg: | |||||
if default_args is None or TYPE_NAME not in default_args: | |||||
raise KeyError( | raise KeyError( | ||||
'`cfg` or `default_args` must contain the key "type", ' | |||||
f'`cfg` or `default_args` must contain the key "{TYPE_NAME}", ' | |||||
f'but got {cfg}\n{default_args}') | f'but got {cfg}\n{default_args}') | ||||
if not isinstance(registry, Registry): | if not isinstance(registry, Registry): | ||||
raise TypeError('registry must be an modelscope.Registry object, ' | raise TypeError('registry must be an modelscope.Registry object, ' | ||||
@@ -184,7 +186,7 @@ def build_from_cfg(cfg, | |||||
if group_key is None: | if group_key is None: | ||||
group_key = default_group | group_key = default_group | ||||
obj_type = args.pop('type') | |||||
obj_type = args.pop(TYPE_NAME) | |||||
if isinstance(obj_type, str): | if isinstance(obj_type, str): | ||||
obj_cls = registry.get(obj_type, group_key=group_key) | obj_cls = registry.get(obj_type, group_key=group_key) | ||||
if obj_cls is None: | if obj_cls is None: | ||||
@@ -196,7 +198,10 @@ def build_from_cfg(cfg, | |||||
raise TypeError( | raise TypeError( | ||||
f'type must be a str or valid type, but got {type(obj_type)}') | f'type must be a str or valid type, but got {type(obj_type)}') | ||||
try: | try: | ||||
return obj_cls(**args) | |||||
if hasattr(obj_cls, '_instantiate'): | |||||
return obj_cls._instantiate(**args) | |||||
else: | |||||
return obj_cls(**args) | |||||
except Exception as e: | except Exception as e: | ||||
# Normal TypeError does not print class name. | # Normal TypeError does not print class name. | ||||
raise type(e)(f'{obj_cls.__name__}: {e}') | raise type(e)(f'{obj_cls.__name__}: {e}') |
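A toy sketch of the new _instantiate hook (not part of the diff): when a registered class defines _instantiate, build_from_cfg now delegates construction to it, which is how BaseTaskModel above gets its checkpoint loaded during construction. The registry name and the class below are illustrative only.

from modelscope.utils.registry import Registry, build_from_cfg

TOYS = Registry('toys')  # illustrative registry, not one that ships with modelscope


@TOYS.register_module(module_name='lazy-toy')
class LazyToy:

    def __init__(self, name):
        self.name = name

    @classmethod
    def _instantiate(cls, **kwargs):
        # a real task model would load pretrained weights here before returning
        return cls(**kwargs)


toy = build_from_cfg(dict(type='lazy-toy', name='demo'), TOYS)
assert toy.name == 'demo'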
@@ -2,6 +2,8 @@ | |||||
# Part of the implementation is borrowed from huggingface/transformers. | # Part of the implementation is borrowed from huggingface/transformers. | ||||
from collections.abc import Mapping | from collections.abc import Mapping | ||||
import numpy as np | |||||
def torch_nested_numpify(tensors): | def torch_nested_numpify(tensors): | ||||
import torch | import torch | ||||
@@ -27,9 +29,6 @@ def torch_nested_detach(tensors): | |||||
def torch_default_data_collator(features): | def torch_default_data_collator(features): | ||||
# TODO @jiangnana.jnn refine this default data collator | # TODO @jiangnana.jnn refine this default data collator | ||||
import torch | import torch | ||||
# if not isinstance(features[0], (dict, BatchEncoding)): | |||||
# features = [vars(f) for f in features] | |||||
first = features[0] | first = features[0] | ||||
if isinstance(first, Mapping): | if isinstance(first, Mapping): | ||||
@@ -40,9 +39,14 @@ def torch_default_data_collator(features): | |||||
if 'label' in first and first['label'] is not None: | if 'label' in first and first['label'] is not None: | ||||
label = first['label'].item() if isinstance( | label = first['label'].item() if isinstance( | ||||
first['label'], torch.Tensor) else first['label'] | first['label'], torch.Tensor) else first['label'] | ||||
dtype = torch.long if isinstance(label, int) else torch.float | |||||
batch['labels'] = torch.tensor([f['label'] for f in features], | |||||
dtype=dtype) | |||||
# MsDataset returns each label as a 0-dimensional np.array holding a single value; the following handles that case. | |||||
if isinstance(label, np.ndarray): | |||||
dtype = torch.long if label[()].dtype == np.int64 else torch.float | |||||
else: | |||||
dtype = torch.long if isinstance(label, int) else torch.float | |||||
batch['labels'] = torch.tensor( | |||||
np.array([f['label'] for f in features]), dtype=dtype) | |||||
elif 'label_ids' in first and first['label_ids'] is not None: | elif 'label_ids' in first and first['label_ids'] is not None: | ||||
if isinstance(first['label_ids'], torch.Tensor): | if isinstance(first['label_ids'], torch.Tensor): | ||||
batch['labels'] = torch.stack( | batch['labels'] = torch.stack( | ||||
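The collator change above accounts for datasets such as MsDataset that yield each label as a 0-dimensional numpy array rather than a plain int or float. A small standalone sketch of the dtype decision follows; it mirrors the logic in the hunk without importing the collator itself.

import numpy as np
import torch

# Each feature carries a 0-d numpy label, as the dataset returns it.
features = [{'label': np.array(1)}, {'label': np.array(0)}]

label = features[0]['label']
if isinstance(label, np.ndarray):
    # label[()] pulls the scalar out of the 0-d array so its dtype can be inspected.
    dtype = torch.long if label[()].dtype == np.int64 else torch.float
else:
    dtype = torch.long if isinstance(label, int) else torch.float

labels = torch.tensor(np.array([f['label'] for f in features]), dtype=dtype)
print(labels)  # tensor([1, 0])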
@@ -0,0 +1,28 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import inspect | |||||
def if_func_recieve_dict_inputs(func, inputs): | |||||
"""to decide if a func could recieve dict inputs or not | |||||
Args: | |||||
func (class): the target function to be inspected | |||||
inputs (dicts): the inputs that will send to the function | |||||
Returns: | |||||
bool: if func recieve dict, then recieve True | |||||
Examples: | |||||
input = {"input_dict":xxx, "attention_masked":xxx}, | |||||
function(self, inputs) then return True | |||||
function(inputs) then return True | |||||
function(self, input_dict, attention_masked) then return False | |||||
""" | |||||
signature = inspect.signature(func) | |||||
func_inputs = list(signature.parameters.keys() - set(['self'])) | |||||
mismatched_inputs = list(set(func_inputs) - set(inputs)) | |||||
return len(func_inputs) == len(mismatched_inputs) | |||||
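A quick check of the helper on the call patterns from its docstring. The new file's import path is not shown in the hunk, so this sketch inlines the same inspect-based comparison rather than importing the helper.

import inspect

def receives_dict(func, inputs):
    # Same check as above: if none of the function's parameter names match the
    # dict keys, assume the function takes the whole dict as one argument.
    params = list(inspect.signature(func).parameters.keys() - {'self'})
    mismatched = set(params) - set(inputs)
    return len(params) == len(mismatched)

def forward_dict(inputs):
    return inputs

def forward_kwargs(input_ids, attention_mask):
    return input_ids, attention_mask

batch = {'input_ids': [1, 2, 3], 'attention_mask': [1, 1, 1]}
print(receives_dict(forward_dict, batch))    # True  -> call as func(batch)
print(receives_dict(forward_kwargs, batch))  # False -> call as func(**batch)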
@@ -7,7 +7,7 @@ import torch | |||||
import torch.nn as nn | import torch.nn as nn | ||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from modelscope.models.base_torch import TorchModel | |||||
from modelscope.models.base import TorchModel | |||||
class TorchBaseTest(unittest.TestCase): | class TorchBaseTest(unittest.TestCase): | ||||
@@ -3,7 +3,8 @@ import unittest | |||||
from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
from modelscope.models import Model | from modelscope.models import Model | ||||
from modelscope.models.nlp import SbertForSentimentClassification | |||||
from modelscope.models.nlp import (SbertForSentimentClassification, | |||||
SequenceClassificationModel) | |||||
from modelscope.pipelines import SentimentClassificationPipeline, pipeline | from modelscope.pipelines import SentimentClassificationPipeline, pipeline | ||||
from modelscope.preprocessors import SentimentClassificationPreprocessor | from modelscope.preprocessors import SentimentClassificationPreprocessor | ||||
from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
@@ -18,39 +19,44 @@ class SentimentClassificationTest(unittest.TestCase): | |||||
def test_run_with_direct_file_download(self): | def test_run_with_direct_file_download(self): | ||||
cache_path = snapshot_download(self.model_id) | cache_path = snapshot_download(self.model_id) | ||||
tokenizer = SentimentClassificationPreprocessor(cache_path) | tokenizer = SentimentClassificationPreprocessor(cache_path) | ||||
model = SbertForSentimentClassification( | |||||
cache_path, tokenizer=tokenizer) | |||||
model = SequenceClassificationModel.from_pretrained( | |||||
self.model_id, num_labels=2) | |||||
pipeline1 = SentimentClassificationPipeline( | pipeline1 = SentimentClassificationPipeline( | ||||
model, preprocessor=tokenizer) | model, preprocessor=tokenizer) | ||||
pipeline2 = pipeline( | pipeline2 = pipeline( | ||||
Tasks.sentiment_classification, | Tasks.sentiment_classification, | ||||
model=model, | model=model, | ||||
preprocessor=tokenizer) | |||||
preprocessor=tokenizer, | |||||
model_revision='beta') | |||||
print(f'sentence1: {self.sentence1}\n' | print(f'sentence1: {self.sentence1}\n' | ||||
f'pipeline1:{pipeline1(input=self.sentence1)}') | f'pipeline1:{pipeline1(input=self.sentence1)}') | ||||
print() | print() | ||||
print(f'sentence1: {self.sentence1}\n' | print(f'sentence1: {self.sentence1}\n' | ||||
f'pipeline1: {pipeline2(input=self.sentence1)}') | f'pipeline1: {pipeline2(input=self.sentence1)}') | ||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
def test_run_with_model_from_modelhub(self): | def test_run_with_model_from_modelhub(self): | ||||
model = Model.from_pretrained(self.model_id) | model = Model.from_pretrained(self.model_id) | ||||
tokenizer = SentimentClassificationPreprocessor(model.model_dir) | tokenizer = SentimentClassificationPreprocessor(model.model_dir) | ||||
pipeline_ins = pipeline( | pipeline_ins = pipeline( | ||||
task=Tasks.sentiment_classification, | task=Tasks.sentiment_classification, | ||||
model=model, | model=model, | ||||
preprocessor=tokenizer) | |||||
preprocessor=tokenizer, | |||||
model_revision='beta') | |||||
print(pipeline_ins(input=self.sentence1)) | print(pipeline_ins(input=self.sentence1)) | ||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
def test_run_with_model_name(self): | def test_run_with_model_name(self): | ||||
pipeline_ins = pipeline( | pipeline_ins = pipeline( | ||||
task=Tasks.sentiment_classification, model=self.model_id) | |||||
task=Tasks.sentiment_classification, | |||||
model=self.model_id, | |||||
model_revision='beta') | |||||
print(pipeline_ins(input=self.sentence1)) | print(pipeline_ins(input=self.sentence1)) | ||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
def test_run_with_default_model(self): | def test_run_with_default_model(self): | ||||
pipeline_ins = pipeline(task=Tasks.sentiment_classification) | |||||
pipeline_ins = pipeline( | |||||
task=Tasks.sentiment_classification, model_revision='beta') | |||||
print(pipeline_ins(input=self.sentence1)) | print(pipeline_ins(input=self.sentence1)) | ||||
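The tests now build the task model through SequenceClassificationModel.from_pretrained instead of constructing SbertForSentimentClassification directly, and pass model_revision='beta' so the hub resolves the backbone/head style configuration. Below is a minimal usage sketch of the same pattern outside the test harness; it assumes the model id from the trainer test below and the continued availability of its beta revision, and the input sentence is only an example.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Assumes the hub serves the backbone/head config under the 'beta' revision.
pipe = pipeline(
    task=Tasks.sentiment_classification,
    model='damo/nlp_structbert_sentiment-classification_chinese-base',
    model_revision='beta')
print(pipe(input='这家餐厅的服务态度很好'))  # example sentence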
@@ -56,6 +56,23 @@ class TestTrainerWithNlp(unittest.TestCase): | |||||
for i in range(10): | for i in range(10): | ||||
self.assertIn(f'epoch_{i+1}.pth', results_files) | self.assertIn(f'epoch_{i+1}.pth', results_files) | ||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
def test_trainer_with_backbone_head(self): | |||||
model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | |||||
kwargs = dict( | |||||
model=model_id, | |||||
train_dataset=self.dataset, | |||||
eval_dataset=self.dataset, | |||||
work_dir=self.tmp_dir, | |||||
model_revision='beta') | |||||
trainer = build_trainer(default_args=kwargs) | |||||
trainer.train() | |||||
results_files = os.listdir(self.tmp_dir) | |||||
self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||||
for i in range(10): | |||||
self.assertIn(f'epoch_{i+1}.pth', results_files) | |||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | ||||
def test_trainer_with_model_and_args(self): | def test_trainer_with_model_and_args(self): | ||||
tmp_dir = tempfile.TemporaryDirectory().name | tmp_dir = tempfile.TemporaryDirectory().name | ||||