Browse Source

1. rename init_embed to embed in models/*; 2. update documents in models/bert.py; 3. update tutorial six.

tags/v0.4.10
Yige Xu 5 years ago
parent
commit
b3718b10dc
12 changed files with 312 additions and 176 deletions
  1. +6
    -0
      docs/source/fastNLP.models.bert.rst
  2. +2
    -1
      docs/source/fastNLP.models.rst
  3. +1
    -1
      docs/source/fastNLP.models.sequence_labeling.rst
  4. +38
    -54
      docs/source/tutorials/tutorial_6_seq_labeling.rst
  5. +8
    -2
      fastNLP/models/__init__.py
  6. +150
    -91
      fastNLP/models/bert.py
  7. +3
    -3
      fastNLP/models/biaffine_parser.py
  8. +3
    -3
      fastNLP/models/cnn_text_classification.py
  9. +5
    -5
      fastNLP/models/snli.py
  10. +12
    -12
      fastNLP/models/star_transformer.py
  11. +82
    -2
      test/models/test_bert.py
  12. +2
    -2
      test/models/test_biaffine_parser.py

+ 6
- 0
docs/source/fastNLP.models.bert.rst View File

@@ -0,0 +1,6 @@
fastNLP.models.bert
===================

.. automodule:: fastNLP.models.bert
:members: BertForSequenceClassification, BertForSentenceMatching, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering


+ 2
- 1
docs/source/fastNLP.models.rst View File

@@ -2,7 +2,7 @@ fastNLP.models
==============

.. automodule:: fastNLP.models
:members: CNNText, SeqLabeling, AdvSeqLabel, ESIM, StarTransEnc, STSeqLabel, STNLICls, STSeqCls, BiaffineParser, GraphParser
:members: CNNText, SeqLabeling, AdvSeqLabel, ESIM, StarTransEnc, STSeqLabel, STNLICls, STSeqCls, BiaffineParser, GraphParser, BertForSequenceClassification, BertForSentenceMatching, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering

子模块
------
@@ -10,6 +10,7 @@ fastNLP.models
.. toctree::
:maxdepth: 1

fastNLP.models.bert
fastNLP.models.biaffine_parser
fastNLP.models.cnn_text_classification
fastNLP.models.sequence_labeling


+ 1
- 1
docs/source/fastNLP.models.sequence_labeling.rst View File

@@ -2,5 +2,5 @@ fastNLP.models.sequence_labeling
================================

.. automodule:: fastNLP.models.sequence_labeling
:members: SeqLabeling, AdvSeqLabel
:members: SeqLabeling, AdvSeqLabel, BiLSTMCRF


+ 38
- 54
docs/source/tutorials/tutorial_6_seq_labeling.rst View File

@@ -3,64 +3,52 @@
=====================

这一部分的内容主要展示如何使用fastNLP 实现序列标注任务。你可以使用fastNLP的各个组件快捷,方便地完成序列标注任务,达到出色的效果。
在阅读这篇Tutorial前,希望你已经熟悉了fastNLP的基础使用,包括基本数据结构以及数据预处理,embedding的嵌入等,希望你对之前的教程有更进一步的掌握
我们将对CoNLL-03的英文数据集进行处理,展示如何完成命名实体标注任务整个训练的过程。
在阅读这篇Tutorial前,希望你已经熟悉了fastNLP的基础使用,尤其是数据的载入以及模型的构建,通过这个小任务的能让你进一步熟悉fastNLP的使用
我们将对基于Weibo的中文社交数据集进行处理,展示如何完成命名实体标注任务整个过程。

载入数据
===================================
fastNLP可以方便地载入各种类型的数据。同时,针对常见的数据集,我们已经预先实现了载入方法,其中包含CoNLL-03数据集。
fastNLP的数据载入主要是由Loader与Pipe两个基类衔接完成的。通过Loader可以方便地载入各种类型的数据。同时,针对常见的数据集,我们已经预先实现了载入方法,其中包含weibo数据集。
在设计dataloader时,以DataSetLoader为基类,可以改写并应用于其他数据集的载入。

.. code-block:: python

class Conll2003DataLoader(DataSetLoader):
def __init__(self, task:str='ner', encoding_type:str='bioes'):
assert task in ('ner', 'pos', 'chunk')
index = {'ner':3, 'pos':1, 'chunk':2}[task]
#ConllLoader是fastNLP内置的类
self._loader = ConllLoader(headers=['raw_words', 'target'], indexes=[0, index])
self._tag_converters = None
if task in ('ner', 'chunk'):
#iob和iob2bioes会对tag进行统一,标准化
self._tag_converters = [iob2]
if encoding_type == 'bioes':
self._tag_converters.append(iob2bioes)

def load(self, path: str):
dataset = self._loader.load(path)
def convert_tag_schema(tags):
for converter in self._tag_converters:
tags = converter(tags)
return tags
if self._tag_converters:
#使用apply实现convert_tag_schema函数,实际上也支持匿名函数
dataset.apply_field(convert_tag_schema, field_name=Const.TARGET, new_field_name=Const.TARGET)
return dataset

输出数据格式如:

{'raw_words': ['on', 'Friday', ':'] type=list,
'target': ['O', 'O', 'O'] type=list},
from fastNLP.io import WeiboNERLoader
data_bundle = WeiboNERLoader().load()



载入后的数据如 ::

{'dev': DataSet(
{{'raw_chars': ['用', '最', '大', '努', '力', '去', '做''人', '生', '。', '哈', '哈', '哈', '哈', '哈', '哈', '
'target': ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O',, 'O', 'O', 'O', 'O', 'O', 'O'] type=list})}

{'test': DataSet(
{{'raw_chars': ['感', '恩', '大', '回', '馈'] type=list, 'target': ['O', 'O', 'O', 'O', 'O'] type=list})}

{'train': DataSet(
{'raw_chars': ['国', '安', '老', '球', '迷'] type=list, 'target': ['B-ORG.NAM', 'I-ORG.NAM', 'B-PER.NOM', 'I-PER.NOM', 'I-PER.NOM'] type=list})}



数据处理
----------------------------
我们进一步处理数据。将数据和词表封装在 :class:`~fastNLP.DataBundle` 类中。data是DataBundle的实例。
我们输入模型的数据包括char embedding,以及word embedding。在数据处理部分,我们尝试完成词表的构建。
使用fastNLP中的Vocabulary类来构建词表。
我们进一步处理数据。通过Pipe基类处理Loader载入的数据。 如果你还有印象,应该还能想起,实现自定义数据集的Pipe时,至少要编写process 函数或者process_from_file 函数。前者接受 :class:`~fastNLP.DataBundle` 类的数据,并返回该 :class:`~fastNLP.DataBundle` 。后者接收数据集所在文件夹为参数,读取并处理为 :class:`~fastNLP.DataBundle` 后,通过process 函数处理数据。
这里我们已经实现通过Loader载入数据,并已返回 :class:`~fastNLP.DataBundle` 类的数据。我们编写process 函数以处理Loader载入后的数据。

.. code-block:: python

word_vocab = Vocabulary(min_freq=2)
word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT)
word_vocab.index_dataset(*data.datasets.values(),field_name=Const.INPUT, new_field_name=Const.INPUT)
from fastNLP.io import ChineseNERPipe
data_bundle = ChineseNERPipe(encoding_type='bioes', bigram=True).process(data_bundle)

处理后的data对象内部为:
载入后的数据如下 ::

dataset
vocabs
dataset保存了train和test中的数据,并保存为dataset类型
vocab保存了words,raw-words以及target的词表。
{'raw_chars': ['用', '最', '大', '努', '力', '去', '做', '值', '得', '的', '事', '人', '生', '。', '哈', '哈', '哈', '哈', '哈', '哈', '我', '在'] type=list,
'target': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] type=list,
'chars': [97, 71, 34, 422, 104, 72, 144, 628, 66, 3, 158, 2, 9, 647, 485, 196, 2,19] type=list,
'bigrams': [5948, 1950, 34840, 98, 8413, 3961, 34841, 631, 34842, 407, 462, 45, 3 1959, 1619, 3, 3, 3, 3, 3, 2663, 29, 90] type=list,
'seq_len': 30 type=int}

模型构建
--------------------------------
@@ -69,27 +57,23 @@ fastNLP可以方便地载入各种类型的数据。同时,针对常见的数

模型的训练
首先实例化模型,导入所需的char embedding以及word embedding。Embedding的载入可以参考教程。
也可以查看 :mod:`~fastNLP.modules.encoder.embedding` 使用所需的embedding 载入方法。
fastNLP将模型的训练过程封装在了 :class:`~fastnlp.trainer` 类中。
也可以查看 :mod:`~fastNLP.embedding` 使用所需的embedding 载入方法。
fastNLP将模型的训练过程封装在了 :class:`~fastnlp.Trainer` 类中。
根据不同的任务调整trainer中的参数即可。通常,一个trainer实例需要有:指定的训练数据集,模型,优化器,loss函数,评测指标,以及指定训练的epoch数,batch size等参数。

.. code-block:: python

#实例化模型
model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type)
#定义优化器
optimizer = Adam(model.parameters(), lr=0.005)
model = CNBiLSTMCRFNER(char_embed, num_classes=len(data_bundle.vocabs['target']), bigram_embed=bigram_embed)
#定义评估指标
Metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type)
#实例化trainer
trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, dev_data=data.datasets['test'], batch_size=10, metrics=Metrics,callbacks=callbacks, n_epochs=100)
#开始训练
trainer.train()
Metrics=SpanFPreRecMetric(data_bundle.vocabs['target'], encoding_type='bioes')
#实例化trainer并训练
Trainer(data_bundle.datasets['train'], model, batch_size=20, metrics=Metrics, num_workers=2, dev_data=data_bundle. datasets['dev']).train()

训练中会保存最优的参数配置。
训练的结果如下:

.. code-block:: python
训练的结果如下 ::

Evaluation on DataSet test:
SpanFPreRecMetric: f=0.727661, pre=0.732293, rec=0.723088


+ 8
- 2
fastNLP/models/__init__.py View File

@@ -21,12 +21,18 @@ __all__ = [
"STSeqCls",
"BiaffineParser",
"GraphParser"
"GraphParser",

"BertForSequenceClassification",
"BertForSentenceMatching",
"BertForMultipleChoice",
"BertForTokenClassification",
"BertForQuestionAnswering"
]

from .base_model import BaseModel
from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequenceClassification, \
BertForTokenClassification
BertForTokenClassification, BertForSentenceMatching
from .biaffine_parser import BiaffineParser, GraphParser
from .cnn_text_classification import CNNText
from .sequence_labeling import SeqLabeling, AdvSeqLabel


+ 150
- 91
fastNLP/models/bert.py View File

@@ -1,9 +1,35 @@
"""undocumented
bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0.
"""
fastNLP提供了BERT应用到五个下游任务的模型代码,可以直接调用。这五个任务分别为

- 文本分类任务: :class:`~fastNLP.models.BertForSequenceClassification`
- Matching任务: :class:`~fastNLP.models.BertForSentenceMatching`
- 多选任务: :class:`~fastNLP.models.BertForMultipleChoice`
- 序列标注任务: :class:`~fastNLP.models.BertForTokenClassification`
- 抽取式QA任务: :class:`~fastNLP.models.BertForQuestionAnswering`

每一个模型必须要传入一个名字为 `embed` 的 :class:`fastNLP.embeddings.BertEmbedding` ,这个参数包含了
:class:`fastNLP.modules.encoder.BertModel` ,是下游模型的编码器(encoder)。

除此以外,还需要传入一个数字,这个数字在不同下游任务模型上的意义如下::

下游任务模型 参数名称 含义
BertForSequenceClassification num_labels 文本分类类别数目,默认值为2
BertForSentenceMatching num_labels Matching任务类别数目,默认值为2
BertForMultipleChoice num_choices 多选任务选项数目,默认值为2
BertForTokenClassification num_labels 序列标注标签数目,无默认值
BertForQuestionAnswering num_labels 抽取式QA列数,默认值为2(即第一列为start_span, 第二列为end_span)

最后还可以传入dropout的大小,默认值为0.1。

"""

__all__ = []
__all__ = [
"BertForSequenceClassification",
"BertForSentenceMatching",
"BertForMultipleChoice",
"BertForTokenClassification",
"BertForQuestionAnswering"
]

import warnings

@@ -13,28 +39,40 @@ from torch import nn
from .base_model import BaseModel
from ..core.const import Const
from ..core._logger import logger
from ..modules.encoder import BertModel
from ..modules.encoder.bert import BertConfig, CONFIG_FILE
from ..embeddings.bert_embedding import BertEmbedding
from ..embeddings import BertEmbedding


class BertForSequenceClassification(BaseModel):
"""BERT model for classification.
"""
def __init__(self, init_embed: BertEmbedding, num_labels: int=2):
别名: :class:`fastNLP.models.BertForSequenceClassification`
:class:`fastNLP.models.bert.BertForSequenceClassification`

BERT model for classification.

:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder).
:param int num_labels: 文本分类类别数目,默认值为2.
:param float dropout: dropout的大小,默认值为0.1.
"""
def __init__(self, embed: BertEmbedding, num_labels: int=2, dropout=0.1):
super(BertForSequenceClassification, self).__init__()

self.num_labels = num_labels
self.bert = init_embed
self.dropout = nn.Dropout(0.1)
self.bert = embed
self.dropout = nn.Dropout(p=dropout)
self.classifier = nn.Linear(self.bert.embedding_dim, num_labels)

if not self.bert.model.include_cls_sep:
warn_msg = "Bert for sequence classification excepts BertEmbedding `include_cls_sep` True, but got False."
self.bert.model.include_cls_sep = True
warn_msg = "Bert for sequence classification excepts BertEmbedding `include_cls_sep` True, " \
"but got False. FastNLP has changed it to True."
logger.warn(warn_msg)
warnings.warn(warn_msg)

def forward(self, words):
"""
:param torch.LongTensor words: [batch_size, seq_len]
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.Tensor [batch_size, num_labels]
"""
hidden = self.dropout(self.bert(words))
cls_hidden = hidden[:, 0]
logits = self.classifier(cls_hidden)
@@ -42,172 +80,193 @@ class BertForSequenceClassification(BaseModel):
return {Const.OUTPUT: logits}

def predict(self, words):
"""
:param torch.LongTensor words: [batch_size, seq_len]
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size]
"""
logits = self.forward(words)[Const.OUTPUT]
return {Const.OUTPUT: torch.argmax(logits, dim=-1)}


class BertForSentenceMatching(BaseModel):
"""
别名: :class:`fastNLP.models.BertForSentenceMatching`
:class:`fastNLP.models.bert.BertForSentenceMatching`

BERT model for sentence matching.

"""BERT model for matching.
:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder).
:param int num_labels: Matching任务类别数目,默认值为2.
:param float dropout: dropout的大小,默认值为0.1.
"""
def __init__(self, init_embed: BertEmbedding, num_labels: int=2):
def __init__(self, embed: BertEmbedding, num_labels: int=2, dropout=0.1):
super(BertForSentenceMatching, self).__init__()
self.num_labels = num_labels
self.bert = init_embed
self.dropout = nn.Dropout(0.1)
self.bert = embed
self.dropout = nn.Dropout(p=dropout)
self.classifier = nn.Linear(self.bert.embedding_dim, num_labels)

if not self.bert.model.include_cls_sep:
error_msg = "Bert for sentence matching excepts BertEmbedding `include_cls_sep` True, but got False."
logger.error(error_msg)
raise RuntimeError(error_msg)
self.bert.model.include_cls_sep = True
warn_msg = "Bert for sentence matching excepts BertEmbedding `include_cls_sep` True, " \
"but got False. FastNLP has changed it to True."
logger.warn(warn_msg)
warnings.warn(warn_msg)

def forward(self, words):
hidden = self.dropout(self.bert(words))
cls_hidden = hidden[:, 0]
"""
:param torch.LongTensor words: [batch_size, seq_len]
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.Tensor [batch_size, num_labels]
"""
hidden = self.bert(words)
cls_hidden = self.dropout(hidden[:, 0])
logits = self.classifier(cls_hidden)

return {Const.OUTPUT: logits}

def predict(self, words):
"""
:param torch.LongTensor words: [batch_size, seq_len]
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size]
"""
logits = self.forward(words)[Const.OUTPUT]
return {Const.OUTPUT: torch.argmax(logits, dim=-1)}


class BertForMultipleChoice(BaseModel):
"""BERT model for multiple choice tasks.
"""
def __init__(self, init_embed: BertEmbedding, num_choices=2):
别名: :class:`fastNLP.models.BertForMultipleChoice`
:class:`fastNLP.models.bert.BertForMultipleChoice`

BERT model for multiple choice.

:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder).
:param int num_choices: 多选任务选项数目,默认值为2.
:param float dropout: dropout的大小,默认值为0.1.
"""
def __init__(self, embed: BertEmbedding, num_choices=2, dropout=0.1):
super(BertForMultipleChoice, self).__init__()

self.num_choices = num_choices
self.bert = init_embed
self.dropout = nn.Dropout(0.1)
self.bert = embed
self.dropout = nn.Dropout(p=dropout)
self.classifier = nn.Linear(self.bert.embedding_dim, 1)
self.include_cls_sep = init_embed.model.include_cls_sep

if not self.bert.model.include_cls_sep:
error_msg = "Bert for multiple choice excepts BertEmbedding `include_cls_sep` True, but got False."
logger.error(error_msg)
raise RuntimeError(error_msg)
self.bert.model.include_cls_sep = True
warn_msg = "Bert for multiple choice excepts BertEmbedding `include_cls_sep` True, " \
"but got False. FastNLP has changed it to True."
logger.warn(warn_msg)
warnings.warn(warn_msg)

def forward(self, words):
"""
:param torch.Tensor words: [batch_size, num_choices, seq_len]
:return: [batch_size, num_labels]
:param torch.LongTensor words: [batch_size, num_choices, seq_len]
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size, num_choices]
"""
batch_size, num_choices, seq_len = words.size()

input_ids = words.view(batch_size * num_choices, seq_len)
hidden = self.bert(input_ids)
pooled_output = hidden[:, 0]
pooled_output = self.dropout(pooled_output)
pooled_output = self.dropout(hidden[:, 0])
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, self.num_choices)

return {Const.OUTPUT: reshaped_logits}

def predict(self, words):
"""
:param torch.LongTensor words: [batch_size, num_choices, seq_len]
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size]
"""
logits = self.forward(words)[Const.OUTPUT]
return {Const.OUTPUT: torch.argmax(logits, dim=-1)}


class BertForTokenClassification(BaseModel):
"""BERT model for token-level classification.
"""
def __init__(self, init_embed: BertEmbedding, num_labels):
别名: :class:`fastNLP.models.BertForTokenClassification`
:class:`fastNLP.models.bert.BertForTokenClassification`

BERT model for token classification.

:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder).
:param int num_labels: 序列标注标签数目,无默认值.
:param float dropout: dropout的大小,默认值为0.1.
"""
def __init__(self, embed: BertEmbedding, num_labels, dropout=0.1):
super(BertForTokenClassification, self).__init__()

self.num_labels = num_labels
self.bert = init_embed
self.dropout = nn.Dropout(0.1)
self.bert = embed
self.dropout = nn.Dropout(p=dropout)
self.classifier = nn.Linear(self.bert.embedding_dim, num_labels)
self.include_cls_sep = init_embed.model.include_cls_sep

if self.include_cls_sep:
warn_msg = "Bert for token classification excepts BertEmbedding `include_cls_sep` False, but got True."
warnings.warn(warn_msg)
if self.bert.model.include_cls_sep:
self.bert.model.include_cls_sep = False
warn_msg = "Bert for token classification excepts BertEmbedding `include_cls_sep` False, " \
"but got True. FastNLP has changed it to False."
logger.warn(warn_msg)
warnings.warn(warn_msg)

def forward(self, words):
"""
:param torch.Tensor words: [batch_size, seq_len]
:return: [batch_size, seq_len, num_labels]
:param torch.LongTensor words: [batch_size, seq_len]
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.Tensor [batch_size, seq_len, num_labels]
"""
sequence_output = self.bert(words)
if self.include_cls_sep:
sequence_output = sequence_output[:, 1: -1] # [batch_size, seq_len, embed_dim]
sequence_output = self.bert(words) # [batch_size, seq_len, embed_dim]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)

return {Const.OUTPUT: logits}

def predict(self, words):
"""
:param torch.LongTensor words: [batch_size, seq_len]
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size, seq_len]
"""
logits = self.forward(words)[Const.OUTPUT]
return {Const.OUTPUT: torch.argmax(logits, dim=-1)}


class BertForQuestionAnswering(BaseModel):
"""BERT model for Question Answering (span extraction).
This module is composed of the BERT model with a linear layer on top of
the sequence output that computes start_logits and end_logits
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`bert_dir`: a dir which contains the bert parameters within file `pytorch_model.bin`
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
into account for computing the loss.
`end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size].
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
into account for computing the loss.
Outputs:
if `start_positions` and `end_positions` are not `None`:
Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
if `start_positions` or `end_positions` is `None`:
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
position tokens of shape [batch_size, sequence_length].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
bert_dir = 'your-bert-file-dir'
model = BertForQuestionAnswering(config, bert_dir)
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, init_embed: BertEmbedding, num_labels=2):
别名: :class:`fastNLP.models.BertForQuestionAnswering`
:class:`fastNLP.models.bert.BertForQuestionAnswering`

BERT model for classification.

:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder).
:param int num_labels: 抽取式QA列数,默认值为2(即第一列为start_span, 第二列为end_span).
"""
def __init__(self, embed: BertEmbedding, num_labels=2):
super(BertForQuestionAnswering, self).__init__()

self.bert = init_embed
self.bert = embed
self.num_labels = num_labels
self.qa_outputs = nn.Linear(self.bert.embedding_dim, self.num_labels)

if not self.bert.model.include_cls_sep:
error_msg = "Bert for multiple choice excepts BertEmbedding `include_cls_sep` True, but got False."
logger.error(error_msg)
raise RuntimeError(error_msg)
self.bert.model.include_cls_sep = True
warn_msg = "Bert for question answering excepts BertEmbedding `include_cls_sep` True, " \
"but got False. FastNLP has changed it to True."
logger.warn(warn_msg)
warnings.warn(warn_msg)

def forward(self, words):
"""
:param torch.LongTensor words: [batch_size, seq_len]
:return: 一个包含num_labels个logit的dict,每一个logit的形状都是[batch_size, seq_len]
"""
sequence_output = self.bert(words)
logits = self.qa_outputs(sequence_output) # [batch_size, seq_len, num_labels]

return {Const.OUTPUTS(i): logits[:, :, i] for i in range(self.num_labels)}

def predict(self, words):
"""
:param torch.LongTensor words: [batch_size, seq_len]
:return: 一个包含num_labels个logit的dict,每一个logit的形状都是[batch_size]
"""
logits = self.forward(words)
return {Const.OUTPUTS(i): torch.argmax(logits[Const.OUTPUTS(i)], dim=-1) for i in range(self.num_labels)}

+ 3
- 3
fastNLP/models/biaffine_parser.py View File

@@ -245,7 +245,7 @@ class BiaffineParser(GraphParser):
Biaffine Dependency Parser 实现.
论文参考 `Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) <https://arxiv.org/abs/1611.01734>`_ .

:param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
此时就以传入的对象作为embedding
:param pos_vocab_size: part-of-speech 词典大小
@@ -262,7 +262,7 @@ class BiaffineParser(GraphParser):
"""
def __init__(self,
init_embed,
embed,
pos_vocab_size,
pos_emb_dim,
num_label,
@@ -276,7 +276,7 @@ class BiaffineParser(GraphParser):
super(BiaffineParser, self).__init__()
rnn_out_size = 2 * rnn_hidden_size
word_hid_dim = pos_hid_dim = rnn_hidden_size
self.word_embedding = get_embeddings(init_embed)
self.word_embedding = get_embeddings(embed)
word_emb_dim = self.word_embedding.embedding_dim
self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim)
self.word_fc = nn.Linear(word_emb_dim, word_hid_dim)


+ 3
- 3
fastNLP/models/cnn_text_classification.py View File

@@ -23,7 +23,7 @@ class CNNText(torch.nn.Module):
使用CNN进行文本分类的模型
'Yoon Kim. 2014. Convolution Neural Networks for Sentence Classification.'
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: Embedding的大小(传入tuple(int, int),
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int),
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding
:param int num_classes: 一共有多少类
:param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致
@@ -31,7 +31,7 @@ class CNNText(torch.nn.Module):
:param float dropout: Dropout的大小
"""

def __init__(self, init_embed,
def __init__(self, embed,
num_classes,
kernel_nums=(30, 40, 50),
kernel_sizes=(1, 3, 5),
@@ -39,7 +39,7 @@ class CNNText(torch.nn.Module):
super(CNNText, self).__init__()

# no support for pre-trained embedding currently
self.embed = embedding.Embedding(init_embed)
self.embed = embedding.Embedding(embed)
self.conv_pool = encoder.ConvMaxpool(
in_channels=self.embed.embedding_dim,
out_channels=kernel_nums,


+ 5
- 5
fastNLP/models/snli.py View File

@@ -24,21 +24,21 @@ class ESIM(BaseModel):
ESIM model的一个PyTorch实现
论文参见: https://arxiv.org/pdf/1609.06038.pdf

:param init_embedding: 初始化的Embedding
:param embed: 初始化的Embedding
:param int hidden_size: 隐藏层大小,默认值为Embedding的维度
:param int num_labels: 目标标签种类数量,默认值为3
:param float dropout_rate: dropout的比率,默认值为0.3
:param float dropout_embed: 对Embedding的dropout比率,默认值为0.1
"""

def __init__(self, init_embedding, hidden_size=None, num_labels=3, dropout_rate=0.3,
def __init__(self, embed, hidden_size=None, num_labels=3, dropout_rate=0.3,
dropout_embed=0.1):
super(ESIM, self).__init__()

if isinstance(init_embedding, TokenEmbedding) or isinstance(init_embedding, Embedding):
self.embedding = init_embedding
if isinstance(embed, TokenEmbedding) or isinstance(embed, Embedding):
self.embedding = embed
else:
self.embedding = Embedding(init_embedding)
self.embedding = Embedding(embed)
self.dropout_embed = EmbedDropout(p=dropout_embed)
if hidden_size is None:
hidden_size = self.embedding.embed_size


+ 12
- 12
fastNLP/models/star_transformer.py View File

@@ -23,7 +23,7 @@ class StarTransEnc(nn.Module):

带word embedding的Star-Transformer Encoder

:param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
此时就以传入的对象作为embedding
:param hidden_size: 模型中特征维度.
@@ -35,7 +35,7 @@ class StarTransEnc(nn.Module):
:param dropout: 模型除词嵌入外的dropout概率.
"""

def __init__(self, init_embed,
def __init__(self, embed,
hidden_size,
num_layers,
num_head,
@@ -44,7 +44,7 @@ class StarTransEnc(nn.Module):
emb_dropout,
dropout):
super(StarTransEnc, self).__init__()
self.embedding = get_embeddings(init_embed)
self.embedding = get_embeddings(embed)
emb_dim = self.embedding.embedding_dim
self.emb_fc = nn.Linear(emb_dim, hidden_size)
# self.emb_drop = nn.Dropout(emb_dropout)
@@ -108,7 +108,7 @@ class STSeqLabel(nn.Module):

用于序列标注的Star-Transformer模型

:param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
此时就以传入的对象作为embedding
:param num_cls: 输出类别个数
@@ -122,7 +122,7 @@ class STSeqLabel(nn.Module):
:param dropout: 模型除词嵌入外的dropout概率. Default: 0.1
"""

def __init__(self, init_embed, num_cls,
def __init__(self, embed, num_cls,
hidden_size=300,
num_layers=4,
num_head=8,
@@ -132,7 +132,7 @@ class STSeqLabel(nn.Module):
emb_dropout=0.1,
dropout=0.1, ):
super(STSeqLabel, self).__init__()
self.enc = StarTransEnc(init_embed=init_embed,
self.enc = StarTransEnc(embed=embed,
hidden_size=hidden_size,
num_layers=num_layers,
num_head=num_head,
@@ -173,7 +173,7 @@ class STSeqCls(nn.Module):

用于分类任务的Star-Transformer

:param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
此时就以传入的对象作为embedding
:param num_cls: 输出类别个数
@@ -187,7 +187,7 @@ class STSeqCls(nn.Module):
:param dropout: 模型除词嵌入外的dropout概率. Default: 0.1
"""

def __init__(self, init_embed, num_cls,
def __init__(self, embed, num_cls,
hidden_size=300,
num_layers=4,
num_head=8,
@@ -197,7 +197,7 @@ class STSeqCls(nn.Module):
emb_dropout=0.1,
dropout=0.1, ):
super(STSeqCls, self).__init__()
self.enc = StarTransEnc(init_embed=init_embed,
self.enc = StarTransEnc(embed=embed,
hidden_size=hidden_size,
num_layers=num_layers,
num_head=num_head,
@@ -238,7 +238,7 @@ class STNLICls(nn.Module):
用于自然语言推断(NLI)的Star-Transformer

:param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
此时就以传入的对象作为embedding
:param num_cls: 输出类别个数
@@ -252,7 +252,7 @@ class STNLICls(nn.Module):
:param dropout: 模型除词嵌入外的dropout概率. Default: 0.1
"""

def __init__(self, init_embed, num_cls,
def __init__(self, embed, num_cls,
hidden_size=300,
num_layers=4,
num_head=8,
@@ -262,7 +262,7 @@ class STNLICls(nn.Module):
emb_dropout=0.1,
dropout=0.1, ):
super(STNLICls, self).__init__()
self.enc = StarTransEnc(init_embed=init_embed,
self.enc = StarTransEnc(embed=embed,
hidden_size=hidden_size,
num_layers=num_layers,
num_head=num_head,


+ 82
- 2
test/models/test_bert.py View File

@@ -23,10 +23,25 @@ class TestBert(unittest.TestCase):
self.assertTrue(Const.OUTPUT in pred)
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))

pred = model.predict(input_ids)
pred = model(input_ids)
self.assertTrue(isinstance(pred, dict))
self.assertTrue(Const.OUTPUT in pred)
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2,))
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))

def test_bert_1_w(self):
vocab = Vocabulary().add_word_lst("this is a test .".split())
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
include_cls_sep=False)

with self.assertWarns(Warning):
model = BertForSequenceClassification(embed, 2)

input_ids = torch.LongTensor([[1, 2, 3], [5, 6, 0]])

pred = model.predict(input_ids)
self.assertTrue(isinstance(pred, dict))
self.assertTrue(Const.OUTPUT in pred)
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2,))

def test_bert_2(self):

@@ -44,6 +59,23 @@ class TestBert(unittest.TestCase):
self.assertTrue(Const.OUTPUT in pred)
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (1, 2))

def test_bert_2_w(self):

vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
include_cls_sep=False)

with self.assertWarns(Warning):
model = BertForMultipleChoice(embed, 2)

input_ids = torch.LongTensor([[[2, 6, 7], [1, 6, 5]]])
print(input_ids.size())

pred = model.predict(input_ids)
self.assertTrue(isinstance(pred, dict))
self.assertTrue(Const.OUTPUT in pred)
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (1,))

def test_bert_3(self):

vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
@@ -58,6 +90,22 @@ class TestBert(unittest.TestCase):
self.assertTrue(Const.OUTPUT in pred)
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 3, 7))

def test_bert_3_w(self):

vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
include_cls_sep=True)

with self.assertWarns(Warning):
model = BertForTokenClassification(embed, 7)

input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])

pred = model.predict(input_ids)
self.assertTrue(isinstance(pred, dict))
self.assertTrue(Const.OUTPUT in pred)
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 3))

def test_bert_4(self):

vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
@@ -79,6 +127,22 @@ class TestBert(unittest.TestCase):
self.assertTrue(isinstance(pred, dict))
self.assertEqual(len(pred), 7)

def test_bert_4_w(self):

vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
include_cls_sep=False)

with self.assertWarns(Warning):
model = BertForQuestionAnswering(embed)

input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])

pred = model.predict(input_ids)
self.assertTrue(isinstance(pred, dict))
self.assertTrue(Const.OUTPUTS(1) in pred)
self.assertEqual(tuple(pred[Const.OUTPUTS(1)].shape), (2,))

def test_bert_5(self):

vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
@@ -93,3 +157,19 @@ class TestBert(unittest.TestCase):
self.assertTrue(Const.OUTPUT in pred)
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))

def test_bert_5_w(self):

vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
include_cls_sep=False)

with self.assertWarns(Warning):
model = BertForSentenceMatching(embed)

input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])

pred = model.predict(input_ids)
self.assertTrue(isinstance(pred, dict))
self.assertTrue(Const.OUTPUT in pred)
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2,))


+ 2
- 2
test/models/test_biaffine_parser.py View File

@@ -27,7 +27,7 @@ def prepare_parser_data():

class TestBiaffineParser(unittest.TestCase):
def test_train(self):
model = BiaffineParser(init_embed=(VOCAB_SIZE, 10),
model = BiaffineParser(embed=(VOCAB_SIZE, 10),
pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10,
rnn_hidden_size=10,
arc_mlp_size=10,
@@ -37,7 +37,7 @@ class TestBiaffineParser(unittest.TestCase):
RUNNER.run_model(model, ds, loss=ParserLoss(), metrics=ParserMetric())

def test_train2(self):
model = BiaffineParser(init_embed=(VOCAB_SIZE, 10),
model = BiaffineParser(embed=(VOCAB_SIZE, 10),
pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10,
rnn_hidden_size=16,
arc_mlp_size=10,


Loading…
Cancel
Save