
fix token classification bugs

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10385225

    * fix token classification bugs
master
zhangzhicheng.zzc yingda.chen 3 years ago
parent commit 14e52b308a
4 changed files with 47 additions and 26 deletions
  1. +2  -10  modelscope/models/nlp/bert/__init__.py
  2. +12 -13  modelscope/models/nlp/bert/modeling_bert.py
  3. +27 -2   modelscope/models/nlp/token_classification.py
  4. +6  -1   modelscope/pipelines/nlp/token_classification_pipeline.py

+2 -10  modelscope/models/nlp/bert/__init__.py

@@ -5,7 +5,6 @@ from modelscope.utils.import_utils import LazyImportModule
 
 if TYPE_CHECKING:
     from .modeling_bert import (
-        BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         BertForMaskedLM,
         BertForMultipleChoice,
         BertForNextSentencePrediction,
@@ -20,21 +19,14 @@ if TYPE_CHECKING:
         load_tf_weights_in_bert,
     )
 
-    from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig, BertOnnxConfig
-    from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
-    from .tokenization_bert_fast import BertTokenizerFast
+    from .configuration_bert import BertConfig, BertOnnxConfig
 
 else:
     _import_structure = {
-        'configuration_bert':
-        ['BERT_PRETRAINED_CONFIG_ARCHIVE_MAP', 'BertConfig', 'BertOnnxConfig'],
-        'tokenization_bert':
-        ['BasicTokenizer', 'BertTokenizer', 'WordpieceTokenizer'],
+        'configuration_bert': ['BertConfig', 'BertOnnxConfig'],
     }
-    _import_structure['tokenization_bert_fast'] = ['BertTokenizerFast']
 
     _import_structure['modeling_bert'] = [
-        'BERT_PRETRAINED_MODEL_ARCHIVE_LIST',
        'BertForMaskedLM',
        'BertForMultipleChoice',
        'BertForNextSentencePrediction',
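
The trimmed lazy-import table still exposes BertConfig and BertOnnxConfig, so downstream imports of those names keep working; only the pretrained archive constants and the slow/fast tokenizer re-exports are gone. A minimal sketch (not part of the commit, assuming a modelscope build that contains this module) of consuming what remains; the num_labels value is an arbitrary example:

# Sketch only: imports the names the trimmed __init__.py still exposes.
from modelscope.models.nlp.bert import BertConfig, BertOnnxConfig

config = BertConfig(num_labels=9)  # e.g. a 9-tag BIO label set (illustrative)
print(config.num_labels)
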


+12 -13  modelscope/models/nlp/bert/modeling_bert.py

@@ -1872,19 +1872,18 @@ class BertForTokenClassification(BertPreTrainedModel):
 
     @add_start_docstrings_to_model_forward(
         BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length'))
-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+    def forward(self,
+                input_ids=None,
+                attention_mask=None,
+                token_type_ids=None,
+                position_ids=None,
+                head_mask=None,
+                inputs_embeds=None,
+                labels=None,
+                output_attentions=None,
+                output_hidden_states=None,
+                return_dict=None,
+                **kwargs):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`,
             *optional*):
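
The functional change here is the trailing **kwargs: extra batch keys that ModelScope preprocessors may attach no longer break the call when the batch dict is splatted into forward. A small self-contained sketch of that difference; the offset_mapping key is only an illustrative assumption, not something the commit adds:

import torch

def forward_strict(input_ids=None, attention_mask=None, labels=None):
    # Mirrors the old style: unknown keyword arguments raise TypeError.
    return input_ids

def forward_relaxed(input_ids=None, attention_mask=None, labels=None, **kwargs):
    # Mirrors the new style: unknown keyword arguments are absorbed.
    return input_ids

batch = {
    'input_ids': torch.tensor([[101, 102]]),
    'attention_mask': torch.tensor([[1, 1]]),
    'offset_mapping': torch.tensor([[[0, 0], [0, 1]]]),  # illustrative extra key
}

try:
    forward_strict(**batch)
except TypeError as err:
    print('old signature rejects the batch:', err)

print('new signature accepts it:', forward_relaxed(**batch).shape)
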


+27 -2  modelscope/models/nlp/token_classification.py

@@ -176,7 +176,7 @@ class SbertForTokenClassification(TokenClassification, SbertPreTrainedModel):
 
 @MODELS.register_module(Tasks.word_segmentation, module_name=Models.bert)
 @MODELS.register_module(Tasks.token_classification, module_name=Models.bert)
-class BertForSequenceClassification(TokenClassification, BertPreTrainedModel):
+class BertForTokenClassification(TokenClassification, BertPreTrainedModel):
     """Bert token classification model.
 
     Inherited from TokenClassificationBase.
@@ -187,7 +187,7 @@ class BertForSequenceClassification(TokenClassification, BertPreTrainedModel):
 
     def __init__(self, config, model_dir):
         if hasattr(config, 'base_model_prefix'):
-            BertForSequenceClassification.base_model_prefix = config.base_model_prefix
+            BertForTokenClassification.base_model_prefix = config.base_model_prefix
         super().__init__(config, model_dir)
 
     def build_base_model(self):
@@ -218,3 +218,28 @@ class BertForSequenceClassification(TokenClassification, BertPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             **kwargs)
+
+    @classmethod
+    def _instantiate(cls, **kwargs):
+        """Instantiate the model.
+
+        @param kwargs: Input args.
+            model_dir: The model dir used to load the checkpoint and the label information.
+            num_labels: An optional arg to tell the model how many classes to initialize.
+                Method will call utils.parse_label_mapping if num_labels not supplied.
+                If num_labels is not found, the model will use the default setting (2 classes).
+        @return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
+        """
+        model_dir = kwargs.get('model_dir')
+        num_labels = kwargs.get('num_labels')
+        if num_labels is None:
+            label2id = parse_label_mapping(model_dir)
+            if label2id is not None and len(label2id) > 0:
+                num_labels = len(label2id)
+
+        model_args = {} if num_labels is None else {'num_labels': num_labels}
+        return super(BertPreTrainedModel,
+                     BertForTokenClassification).from_pretrained(
+                         pretrained_model_name_or_path=kwargs.get('model_dir'),
+                         model_dir=kwargs.get('model_dir'),
+                         **model_args)
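
The heart of the new _instantiate hook is how num_labels is resolved: an explicit num_labels argument wins, otherwise the label mapping parsed from model_dir decides, otherwise nothing is passed and the config default (2 classes) applies. A minimal sketch of just that resolution step; the label2id contents are made-up example data, not from the commit:

def resolve_model_args(num_labels=None, label2id=None):
    # Prefer an explicit num_labels, fall back to the size of the parsed
    # label mapping, otherwise pass nothing and keep the config default.
    if num_labels is None and label2id:
        num_labels = len(label2id)
    return {} if num_labels is None else {'num_labels': num_labels}

print(resolve_model_args(label2id={'O': 0, 'B-LOC': 1, 'I-LOC': 2}))  # {'num_labels': 3}
print(resolve_model_args(num_labels=7))                               # {'num_labels': 7}
print(resolve_model_args())                                           # {}
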

+6 -1  modelscope/pipelines/nlp/token_classification_pipeline.py

@@ -40,7 +40,12 @@ class TokenClassificationPipeline(Pipeline):
                 sequence_length=kwargs.pop('sequence_length', 128))
         model.eval()
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)
-        self.id2label = getattr(model, 'id2label')
+        if hasattr(model, 'id2label'):
+            self.id2label = getattr(model, 'id2label')
+        else:
+            model_config = getattr(model, 'config')
+            self.id2label = getattr(model_config, 'id2label')
 
         assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \
             'as a parameter or make sure the preprocessor has the attribute.'
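
This fallback covers models that follow the transformers convention of keeping id2label on model.config rather than on the model object itself. A minimal sketch of the lookup with stand-in classes; both classes below are assumptions for illustration, not ModelScope types:

class _Config:
    id2label = {0: 'O', 1: 'B-PER', 2: 'I-PER'}

class _TransformersStyleModel:
    # No top-level id2label attribute; the mapping only lives on .config.
    config = _Config()

def lookup_id2label(model):
    # Mirrors the pipeline change: try the model first, then its config.
    if hasattr(model, 'id2label'):
        return getattr(model, 'id2label')
    return getattr(getattr(model, 'config'), 'id2label')

print(lookup_id2label(_TransformersStyleModel()))  # {0: 'O', 1: 'B-PER', 2: 'I-PER'}
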


