Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10385225
fix token classification bugs (target branch: master)
@@ -5,7 +5,6 @@ from modelscope.utils.import_utils import LazyImportModule

 if TYPE_CHECKING:
     from .modeling_bert import (
-        BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         BertForMaskedLM,
         BertForMultipleChoice,
         BertForNextSentencePrediction,
@@ -20,21 +19,14 @@ if TYPE_CHECKING:
         load_tf_weights_in_bert,
     )
-    from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig, BertOnnxConfig
-    from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
-    from .tokenization_bert_fast import BertTokenizerFast
+    from .configuration_bert import BertConfig, BertOnnxConfig
 else:
     _import_structure = {
-        'configuration_bert':
-        ['BERT_PRETRAINED_CONFIG_ARCHIVE_MAP', 'BertConfig', 'BertOnnxConfig'],
-        'tokenization_bert':
-        ['BasicTokenizer', 'BertTokenizer', 'WordpieceTokenizer'],
+        'configuration_bert': ['BertConfig', 'BertOnnxConfig'],
     }

-    _import_structure['tokenization_bert_fast'] = ['BertTokenizerFast']
     _import_structure['modeling_bert'] = [
-        'BERT_PRETRAINED_MODEL_ARCHIVE_LIST',
         'BertForMaskedLM',
         'BertForMultipleChoice',
         'BertForNextSentencePrediction',
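For context on the surviving `_import_structure` dict: the tail of such a module (outside this hunk) typically hands the dict to LazyImportModule so that submodules are imported only on first attribute access. A minimal sketch of that pattern, assuming LazyImportModule takes the same constructor arguments as transformers' _LazyModule (an assumption, not confirmed by this diff):

import sys
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

_import_structure = {
    'configuration_bert': ['BertConfig', 'BertOnnxConfig'],
}

if not TYPE_CHECKING:
    # Swap this module for a lazy proxy; '.configuration_bert' is only
    # imported when BertConfig or BertOnnxConfig is first accessed.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )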
@@ -1872,19 +1872,18 @@ class BertForTokenClassification(BertPreTrainedModel):

     @add_start_docstrings_to_model_forward(
         BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length'))
-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+    def forward(self,
+                input_ids=None,
+                attention_mask=None,
+                token_type_ids=None,
+                position_ids=None,
+                head_mask=None,
+                inputs_embeds=None,
+                labels=None,
+                output_attentions=None,
+                output_hidden_states=None,
+                return_dict=None,
+                **kwargs):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`,
             *optional*):
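The practical effect of the added `**kwargs`: extra keys emitted by a preprocessor (for example an `offset_mapping` field, a hypothetical name not taken from this diff) no longer raise a TypeError when the whole feature dict is splatted into forward. A standalone sketch of the failure mode being fixed:

def forward_old(input_ids=None, attention_mask=None):
    return input_ids

def forward_new(input_ids=None, attention_mask=None, **kwargs):
    # Unconsumed preprocessor outputs land in kwargs and are ignored here.
    return input_ids

features = {'input_ids': [1, 2, 3], 'offset_mapping': [(0, 1), (1, 2), (2, 3)]}
forward_new(**features)    # ok: 'offset_mapping' is absorbed by **kwargs
# forward_old(**features)  # TypeError: unexpected keyword argument 'offset_mapping'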
@@ -176,7 +176,7 @@ class SbertForTokenClassification(TokenClassification, SbertPreTrainedModel):

 @MODELS.register_module(Tasks.word_segmentation, module_name=Models.bert)
 @MODELS.register_module(Tasks.token_classification, module_name=Models.bert)
-class BertForSequenceClassification(TokenClassification, BertPreTrainedModel):
+class BertForTokenClassification(TokenClassification, BertPreTrainedModel):
     """Bert token classification model.

     Inherited from TokenClassificationBase.
@@ -187,7 +187,7 @@ class BertForSequenceClassification(TokenClassification, BertPreTrainedModel):

     def __init__(self, config, model_dir):
         if hasattr(config, 'base_model_prefix'):
-            BertForSequenceClassification.base_model_prefix = config.base_model_prefix
+            BertForTokenClassification.base_model_prefix = config.base_model_prefix
         super().__init__(config, model_dir)

     def build_base_model(self):
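A note on the assignment above: in transformers, `base_model_prefix` is a class attribute that `PreTrainedModel` uses to locate the backbone submodule when loading checkpoints, so keys like `bert.encoder...` resolve against `getattr(self, 'bert')`. Copying the prefix from the checkpoint's config onto the class keeps weight loading consistent when a checkpoint was saved under a different prefix. A toy illustration (class and prefix names are illustrative only, not from this diff):

class TokenClsModel:
    base_model_prefix = 'bert'      # default prefix baked into the class

class Cfg:
    base_model_prefix = 'encoder'   # prefix carried by a checkpoint's config

config = Cfg()
if hasattr(config, 'base_model_prefix'):
    # Mirror the diff: rebind the class attribute before loading weights,
    # so 'encoder.*' checkpoint keys match the model's attribute name.
    TokenClsModel.base_model_prefix = config.base_model_prefix

assert TokenClsModel.base_model_prefix == 'encoder'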
@@ -218,3 +218,28 @@ class BertForSequenceClassification(TokenClassification, BertPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             **kwargs)
+
+    @classmethod
+    def _instantiate(cls, **kwargs):
+        """Instantiate the model.
+
+        @param kwargs: Input args.
+            model_dir: The model dir used to load the checkpoint and the label information.
+            num_labels: An optional arg telling the model how many classes to initialize.
+                If num_labels is not supplied, the method calls utils.parse_label_mapping;
+                if no label mapping is found either, the model uses the default setting (2 classes).
+        @return: The loaded model, initialized by transformers.PreTrainedModel.from_pretrained.
+        """
+        model_dir = kwargs.get('model_dir')
+        num_labels = kwargs.get('num_labels')
+        if num_labels is None:
+            label2id = parse_label_mapping(model_dir)
+            if label2id is not None and len(label2id) > 0:
+                num_labels = len(label2id)
+        model_args = {} if num_labels is None else {'num_labels': num_labels}
+        return super(BertPreTrainedModel,
+                     BertForTokenClassification).from_pretrained(
+                         pretrained_model_name_or_path=kwargs.get('model_dir'),
+                         model_dir=kwargs.get('model_dir'),
+                         **model_args)
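A hedged usage sketch for the new `_instantiate` hook, assuming the usual ModelScope flow where `Model.from_pretrained` ends up calling the registered class's `_instantiate`. The point of the hook is that `num_labels` may be omitted and recovered from the label mapping stored alongside the checkpoint (the directory path below is a placeholder):

from modelscope.models import Model

# '/path/to/ckpt' stands in for a local directory holding both the weights
# and the label-mapping file that parse_label_mapping reads.
model = Model.from_pretrained('/path/to/ckpt')

# Roughly equivalent direct call; num_labels is inferred when not supplied:
# model = BertForTokenClassification._instantiate(model_dir='/path/to/ckpt')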
@@ -40,7 +40,12 @@ class TokenClassificationPipeline(Pipeline):
                 sequence_length=kwargs.pop('sequence_length', 128))
             model.eval()
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)
-        self.id2label = getattr(model, 'id2label')
+        if hasattr(model, 'id2label'):
+            self.id2label = getattr(model, 'id2label')
+        else:
+            model_config = getattr(model, 'config')
+            self.id2label = getattr(model_config, 'id2label')
         assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \
                                           'as a parameter or make sure the preprocessor has the attribute.'
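End to end, the fallback above matters when the wrapped model exposes `id2label` only on its `config` rather than on the model object itself. A sketch of the entry point, using the standard ModelScope `pipeline` factory (the model id is a placeholder, not a real checkpoint):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Placeholder id: any token-classification checkpoint whose id2label lives
# on model.config; the pipeline now resolves it via the hasattr fallback.
token_cls = pipeline(Tasks.token_classification, model='damo/placeholder-token-cls')
print(token_cls('今天天气不错'))  # sample Chinese input, e.g. for word segmentation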