@@ -155,37 +155,3 @@ fastNLP中field的命名习惯 | |||
- **chars**: 表示已经切分为单独的汉字的序列。例如["这", "是", "一", "个", "示", "例", "。"]。但由于神经网络不能识别汉字,所以一般该列会被转为int形式,如[3, 4, 5, 6, ...]。 | |||
- **target**: 表示目标值。分类场景下,只有一个值;序列标注场景下是一个序列 | |||
- **seq_len**: 表示输入序列的长度 | |||
----------------------------- | |||
DataSet与pad | |||
----------------------------- | |||
.. todo:: | |||
这一段移动到datasetiter那里 | |||
在fastNLP里,pad是与一个 :mod:`~fastNLP.core.field` 绑定的。即不同的 :mod:`~fastNLP.core.field` 可以使用不同的pad方式,比如在英文任务中word需要的pad和 | |||
character的pad方式往往是不同的。fastNLP是通过一个叫做 :class:`~fastNLP.Padder` 的子类来完成的。 | |||
默认情况下,所有field使用 :class:`~fastNLP.AutoPadder` | |||
。可以通过使用以下方式设置Padder(如果将padder设置为None,则该field不会进行pad操作)。 | |||
大多数情况下直接使用 :class:`~fastNLP.AutoPadder` 就可以了。 | |||
如果 :class:`~fastNLP.AutoPadder` 或 :class:`~fastNLP.EngChar2DPadder` 无法满足需求, | |||
也可以自己写一个 :class:`~fastNLP.Padder` 。 | |||
.. code-block:: python | |||
from fastNLP import DataSet | |||
from fastNLP import EngChar2DPadder | |||
import random | |||
dataset = DataSet() | |||
max_chars, max_words, sent_num = 5, 10, 20 | |||
contents = [[ | |||
[random.randint(1, 27) for _ in range(random.randint(1, max_chars))] | |||
for _ in range(random.randint(1, max_words)) | |||
] for _ in range(sent_num)] | |||
# 初始化时传入 | |||
dataset.add_field('chars', contents, padder=EngChar2DPadder()) | |||
# 直接设置 | |||
dataset.set_padder('chars', EngChar2DPadder()) | |||
# 也可以设置pad的value | |||
dataset.set_pad_val('chars', -1) |
@@ -13,7 +13,6 @@ | |||
- `Part V: 不同格式类型的基础Loader`_ | |||
------------------------------------ | |||
Part I: 数据集容器DataBundle | |||
------------------------------------ | |||
@@ -24,7 +23,6 @@ Part I: 数据集容器DataBundle | |||
:class:`~fastNLP.io.DataBundle` 在fastNLP中主要在各个 :class:`~fastNLP.io.Loader` 和 :class:`~fastNLP.io.Pipe` 中被使用。 | |||
下面我们先介绍一下 :class:`~fastNLP.io.Loader` 和 :class:`~fastNLP.io.Pipe` 。 | |||
------------------------------------- | |||
Part II: 加载的各种数据集的Loader | |||
------------------------------------- | |||
@@ -74,7 +72,6 @@ Part II: 加载的各种数据集的Loader | |||
| 中共中央 总书记 、 国家 主席 江 泽民 | | |||
+--------------------------------------------------------------------------------------+ | |||
------------------------------------------ | |||
Part III: 使用Pipe对数据集进行预处理 | |||
------------------------------------------ | |||
通过 :class:`~fastNLP.io.Loader` 可以将文本数据读入,但并不能直接被神经网络使用,还需要进行一定的预处理。 | |||
@@ -84,8 +81,8 @@ Part III: 使用Pipe对数据集进行预处理 | |||
raw_chars进行tokenize以切分成不同的词或字; (2) 再建立词或字的 :class:`~fastNLP.Vocabulary` , 并将词或字转换为index; (3)将target | |||
列建立词表并将target列转为index; | |||
所有的Pipe都可通过其文档查看该Pipe支持处理的 :class:`~fastNLP.DataSet` 以及返回的 :class:`~fastNLP.io.DataSet` 中的field的情况; | |||
如 :class:`~fastNLP.io.` | |||
所有的Pipe都可通过其文档查看该Pipe支持处理的 :class:`~fastNLP.DataSet` 以及返回的 :class:`~fastNLP.io.DataBundle` 中的Vocabulary的情况; | |||
如 :class:`~fastNLP.io.OntoNotesNERPipe` | |||
各种数据集的Pipe当中,都包含了以下的两个函数: | |||
@@ -144,14 +141,14 @@ raw_chars进行tokenize以切分成不同的词或字; (2) 再建立词或字的 | |||
Vocabulary(['B', 'E', 'S', 'M']...) | |||
------------------------------------------ | |||
Part IV: fastNLP封装好的Loader和Pipe | |||
------------------------------------------ | |||
fastNLP封装了多种任务/数据集的 :class:`~fastNLP.io.Loader` 和 :class:`~fastNLP.io.Pipe` 并提供自动下载功能,具体参见文档 | |||
`数据集 <https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0>`_ | |||
-------------------------------------------------------- | |||
Part V: 不同格式类型的基础Loader | |||
-------------------------------------------------------- | |||
@@ -37,12 +37,12 @@ DataSetIter初探之前的内容与 :doc:`/tutorials/tutorial_5_loss_optimizer` | |||
输出数据如下:: | |||
In total 3 datasets: | |||
test has 1821 instances. | |||
train has 67349 instances. | |||
dev has 872 instances. | |||
test has 1821 instances. | |||
train has 67349 instances. | |||
dev has 872 instances. | |||
In total 2 vocabs: | |||
words has 16293 entries. | |||
target has 2 entries. | |||
words has 16293 entries. | |||
target has 2 entries. | |||
+-------------------------------------------+--------+--------------------------------------+---------+ | |||
| raw_words | target | words | seq_len | | |||
@@ -59,9 +59,9 @@ DataSetIter初探之前的内容与 :doc:`/tutorials/tutorial_5_loss_optimizer` | |||
.. code-block:: python | |||
train_data = databundle.datasets['train'] | |||
train_data = databundle.get_dataset('train') | |||
train_data, test_data = train_data.split(0.015) | |||
dev_data = databundle.datasets['dev'] | |||
dev_data = databundle.get_dataset('dev') | |||
print(len(train_data),len(dev_data),len(test_data)) | |||
输出结果为:: | |||
@@ -69,7 +69,10 @@ DataSetIter初探之前的内容与 :doc:`/tutorials/tutorial_5_loss_optimizer` | |||
66339 872 1010 | |||
数据集 :meth:`~fastNLP.DataSet.set_input` 和 :meth:`~fastNLP.DataSet.set_target` 函数 | |||
:class:`~fastNLP.io.SST2Pipe` 类的 :meth:`~fastNLP.io.SST2Pipe.process_from_file` 方法在预处理过程中还将训练、测试、验证集的 `words` 、`seq_len` :mod:`~fastNLP.core.field` 设定为input,同时将`target` :mod:`~fastNLP.core.field` 设定为target。我们可以通过 :class:`~fastNLP.core.Dataset` 类的 :meth:`~fastNLP.core.Dataset.print_field_meta` 方法查看各个 :mod:`~fastNLP.core.field` 的设定情况,代码如下: | |||
:class:`~fastNLP.io.SST2Pipe` 类的 :meth:`~fastNLP.io.SST2Pipe.process_from_file` 方法在预处理过程中还将训练、测试、验证集 | |||
的 `words` 、`seq_len` :mod:`~fastNLP.core.field` 设定为input,同时将`target` :mod:`~fastNLP.core.field` 设定为target。 | |||
我们可以通过 :class:`~fastNLP.core.Dataset` 类的 :meth:`~fastNLP.core.Dataset.print_field_meta` 方法查看各个 | |||
:mod:`~fastNLP.core.field` 的设定情况,代码如下: | |||
.. code-block:: python | |||
@@ -86,9 +89,13 @@ DataSetIter初探之前的内容与 :doc:`/tutorials/tutorial_5_loss_optimizer` | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
其中is_input和is_target分别表示是否为input和target。ignore_type为true时指使用 :class:`~fastNLP.DataSetIter` 取出batch数据时fastNLP不会进行自动padding,pad_value指对应 :mod:`~fastNLP.core.field` padding所用的值,这两者只有当 :mod:`~fastNLP.core.field` 设定为input或者target的时候才有存在的意义。 | |||
其中is_input和is_target分别表示是否为input和target。ignore_type为true时指使用 :class:`~fastNLP.DataSetIter` 取出batch数 | |||
据时fastNLP不会进行自动padding,pad_value指对应 :mod:`~fastNLP.core.field` padding所用的值,这两者只有当 | |||
:mod:`~fastNLP.core.field` 设定为input或者target的时候才有存在的意义。 | |||
is_input为true的 :mod:`~fastNLP.core.field` 在 :class:`~fastNLP.DataSetIter` 迭代取出的 batch_x 中,而 is_target为true的 :mod:`~fastNLP.core.field` 在 :class:`~fastNLP.DataSetIter` 迭代取出的 batch_y 中。具体分析见下面DataSetIter的介绍过程。 | |||
is_input为true的 :mod:`~fastNLP.core.field` 在 :class:`~fastNLP.DataSetIter` 迭代取出的 batch_x 中, | |||
而 is_target为true的 :mod:`~fastNLP.core.field` 在 :class:`~fastNLP.DataSetIter` 迭代取出的 batch_y 中。 | |||
具体分析见下面DataSetIter的介绍过程。 | |||
评价指标 | |||
@@ -111,6 +118,7 @@ DataSetIter初探之前的内容与 :doc:`/tutorials/tutorial_5_loss_optimizer` | |||
-------------------------- | |||
DataSetIter初探 | |||
-------------------------- | |||
DataSetIter | |||
fastNLP定义的 :class:`~fastNLP.DataSetIter` 类,用于定义一个batch,并实现batch的多种功能,在初始化时传入的参数有: | |||
@@ -2,97 +2,123 @@ | |||
快速实现序列标注模型 | |||
===================== | |||
这一部分的内容主要展示如何使用fastNLP 实现序列标注任务。你可以使用fastNLP的各个组件快捷,方便地完成序列标注任务,达到出色的效果。 | |||
在阅读这篇Tutorial前,希望你已经熟悉了fastNLP的基础使用,尤其是数据的载入以及模型的构建,通过这个小任务的能让你进一步熟悉fastNLP的使用。 | |||
我们将对基于Weibo的中文社交数据集进行处理,展示如何完成命名实体标注任务的整个过程。 | |||
这一部分的内容主要展示如何使用fastNLP实现序列标注任务。您可以使用fastNLP的各个组件快捷,方便地完成序列标注任务,达到出色的效果。 | |||
在阅读这篇Tutorial前,希望您已经熟悉了fastNLP的基础使用,尤其是数据的载入以及模型的构建,通过这个小任务的能让您进一步熟悉fastNLP的使用。 | |||
命名实体识别(name entity recognition, NER) | |||
------------------------------------------ | |||
命名实体识别任务是从文本中抽取出具有特殊意义或者指代性非常强的实体,通常包括人名、地名、机构名和时间等。 | |||
如下面的例子中 | |||
我来自复旦大学。 | |||
其中“复旦大学”就是一个机构名,命名实体识别就是要从中识别出“复旦大学”这四个字是一个整体,且属于机构名这个类别。这个问题在实际做的时候会被 | |||
转换为序列标注问题 | |||
针对"我来自复旦大学"这句话,我们的预测目标将是[O, O, O, B-ORG, I-ORG, I-ORG, I-ORG],其中O表示out,即不是一个实体,B-ORG是ORG( | |||
organization的缩写)这个类别的开头(Begin),I-ORG是ORG类别的中间(Inside)。 | |||
在本tutorial中我们将通过fastNLP尝试写出一个能够执行以上任务的模型。 | |||
载入数据 | |||
=================================== | |||
fastNLP的数据载入主要是由Loader与Pipe两个基类衔接完成的。通过Loader可以方便地载入各种类型的数据。同时,针对常见的数据集,我们已经预先实现了载入方法,其中包含weibo数据集。 | |||
在设计dataloader时,以DataSetLoader为基类,可以改写并应用于其他数据集的载入。 | |||
------------------------------------------ | |||
fastNLP的数据载入主要是由Loader与Pipe两个基类衔接完成的,您可以通过 :doc:`使用Loader和Pipe处理数据 </tutorials/tutorial_4_load_dataset>` | |||
了解如何使用fastNLP提供的数据加载函数。下面我们以微博命名实体任务来演示一下在fastNLP进行序列标注任务。 | |||
.. code-block:: python | |||
from fastNLP.io import WeiboNERLoader | |||
data_bundle = WeiboNERLoader().load() | |||
from fastNLP.io import WeiboNERPipe | |||
data_bundle = WeiboNERPipe().process_from_file() | |||
print(data_bundle.get_dataset('train')[:2]) | |||
打印的数据如下 :: | |||
+-------------------------------------------------+------------------------------------------+------------------------------------------+---------+ | |||
| raw_chars | target | chars | seq_len | | |||
+-------------------------------------------------+------------------------------------------+------------------------------------------+---------+ | |||
| ['一', '节', '课', '的', '时', '间', '真', '... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, ... | [8, 211, 775, 3, 49, 245, 89, 26, 101... | 16 | | |||
| ['回', '复', '支', '持', ',', '赞', '成', '... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | [116, 480, 127, 109, 2, 446, 134, 2, ... | 59 | | |||
+-------------------------------------------------+------------------------------------------+------------------------------------------+---------+ | |||
载入后的数据如 :: | |||
{'dev': DataSet( | |||
{{'raw_chars': ['用', '最', '大', '努', '力', '去', '做''人', '生', '。', '哈', '哈', '哈', '哈', '哈', '哈', ' | |||
'target': ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O',, 'O', 'O', 'O', 'O', 'O', 'O'] type=list})} | |||
模型构建 | |||
-------------------------------- | |||
{'test': DataSet( | |||
{{'raw_chars': ['感', '恩', '大', '回', '馈'] type=list, 'target': ['O', 'O', 'O', 'O', 'O'] type=list})} | |||
首先选择需要使用的Embedding类型。关于Embedding的相关说明可以参见 :doc:`使用Embedding模块将文本转成向量 </tutorials/tutorial_3_embedding>` 。 | |||
在这里我们使用通过word2vec预训练的中文汉字embedding。 | |||
{'train': DataSet( | |||
{'raw_chars': ['国', '安', '老', '球', '迷'] type=list, 'target': ['B-ORG.NAM', 'I-ORG.NAM', 'B-PER.NOM', 'I-PER.NOM', 'I-PER.NOM'] type=list})} | |||
.. code-block:: python | |||
from fastNLP.embeddings import StaticEmbedding | |||
embed = StaticEmbedding(vocab=data_bundle.get_vocab('chars'), model_dir_or_name='cn-char-fastnlp-100d') | |||
数据处理 | |||
---------------------------- | |||
我们进一步处理数据。通过Pipe基类处理Loader载入的数据。 如果你还有印象,应该还能想起,实现自定义数据集的Pipe时,至少要编写process 函数或者process_from_file 函数。前者接受 :class:`~fastNLP.DataBundle` 类的数据,并返回该 :class:`~fastNLP.DataBundle` 。后者接收数据集所在文件夹为参数,读取并处理为 :class:`~fastNLP.DataBundle` 后,通过process 函数处理数据。 | |||
这里我们已经实现通过Loader载入数据,并已返回 :class:`~fastNLP.DataBundle` 类的数据。我们编写process 函数以处理Loader载入后的数据。 | |||
选择好Embedding之后,我们可以使用fastNLP中自带的 :class:`fastNLP.models.BiLSTMCRF` 作为模型。 | |||
.. code-block:: python | |||
from fastNLP.io import ChineseNERPipe | |||
data_bundle = ChineseNERPipe(encoding_type='bioes', bigram=True).process(data_bundle) | |||
from fastNLP.models import BiLSTMCRF | |||
载入后的数据如下 :: | |||
data_bundle.rename_field('chars', 'words') # 这是由于BiLSTMCRF模型的forward函数接受的words,而不是chars,所以需要把这一列重新命名 | |||
model = BiLSTMCRF(embed=embed, num_classes=len(data_bundle.get_vocab('target')), num_layers=1, hidden_size=200, dropout=0.5, | |||
target_vocab=data_bundle.get_vocab('target')) | |||
{'raw_chars': ['用', '最', '大', '努', '力', '去', '做', '值', '得', '的', '事', '人', '生', '。', '哈', '哈', '哈', '哈', '哈', '哈', '我', '在'] type=list, | |||
'target': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] type=list, | |||
'chars': [97, 71, 34, 422, 104, 72, 144, 628, 66, 3, 158, 2, 9, 647, 485, 196, 2,19] type=list, | |||
'bigrams': [5948, 1950, 34840, 98, 8413, 3961, 34841, 631, 34842, 407, 462, 45, 3 1959, 1619, 3, 3, 3, 3, 3, 2663, 29, 90] type=list, | |||
'seq_len': 30 type=int} | |||
下面我们选择用来评估模型的metric,以及优化用到的优化函数。 | |||
模型构建 | |||
-------------------------------- | |||
我们使用CNN-BILSTM-CRF模型完成这一任务。在网络构建方面,fastNLP的网络定义继承pytorch的 :class:`nn.Module` 类。 | |||
自己可以按照pytorch的方式定义网络。需要注意的是命名。fastNLP的标准命名位于 :class:`~fastNLP.Const` 类。 | |||
.. code-block:: python | |||
模型的训练 | |||
首先实例化模型,导入所需的char embedding以及word embedding。Embedding的载入可以参考教程。 | |||
也可以查看 :mod:`~fastNLP.embedding` 使用所需的embedding 载入方法。 | |||
fastNLP将模型的训练过程封装在了 :class:`~fastnlp.Trainer` 类中。 | |||
根据不同的任务调整trainer中的参数即可。通常,一个trainer实例需要有:指定的训练数据集,模型,优化器,loss函数,评测指标,以及指定训练的epoch数,batch size等参数。 | |||
from fastNLP import SpanFPreRecMetric | |||
from torch.optim import Adam | |||
from fastNLP import LossInForward | |||
metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target')) | |||
optimizer = Adam(model.parameters(), lr=1e-4) | |||
loss = LossInForward() | |||
使用Trainer进行训练 | |||
.. code-block:: python | |||
#实例化模型 | |||
model = CNBiLSTMCRFNER(char_embed, num_classes=len(data_bundle.vocabs['target']), bigram_embed=bigram_embed) | |||
#定义评估指标 | |||
Metrics=SpanFPreRecMetric(data_bundle.vocabs['target'], encoding_type='bioes') | |||
#实例化trainer并训练 | |||
Trainer(data_bundle.datasets['train'], model, batch_size=20, metrics=Metrics, num_workers=2, dev_data=data_bundle. datasets['dev']).train() | |||
训练中会保存最优的参数配置。 | |||
训练的结果如下 :: | |||
Evaluation on DataSet test: | |||
SpanFPreRecMetric: f=0.727661, pre=0.732293, rec=0.723088 | |||
Evaluation at Epoch 1/100. Step:1405/140500. SpanFPreRecMetric: f=0.727661, pre=0.732293, rec=0.723088 | |||
Evaluation on DataSet test: | |||
SpanFPreRecMetric: f=0.784307, pre=0.779371, rec=0.789306 | |||
Evaluation at Epoch 2/100. Step:2810/140500. SpanFPreRecMetric: f=0.784307, pre=0.779371, rec=0.789306 | |||
Evaluation on DataSet test: | |||
SpanFPreRecMetric: f=0.810068, pre=0.811003, rec=0.809136 | |||
Evaluation at Epoch 3/100. Step:4215/140500. SpanFPreRecMetric: f=0.810068, pre=0.811003, rec=0.809136 | |||
Evaluation on DataSet test: | |||
SpanFPreRecMetric: f=0.829592, pre=0.84153, rec=0.817989 | |||
Evaluation at Epoch 4/100. Step:5620/140500. SpanFPreRecMetric: f=0.829592, pre=0.84153, rec=0.817989 | |||
Evaluation on DataSet test: | |||
SpanFPreRecMetric: f=0.828789, pre=0.837096, rec=0.820644 | |||
Evaluation at Epoch 5/100. Step:7025/140500. SpanFPreRecMetric: f=0.828789, pre=0.837096, rec=0.820644 | |||
from fastNLP import Trainer | |||
import torch | |||
device= 0 if torch.cuda.is_available() else 'cpu' | |||
trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer, | |||
dev_data=data_bundle.get_dataset('dev'), metrics=metric, device=device) | |||
trainer.train() | |||
训练过程输出为:: | |||
input fields after batch(if batch size is 2): | |||
target: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) | |||
seq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) | |||
words: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) | |||
target fields after batch(if batch size is 2): | |||
target: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) | |||
seq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) | |||
training epochs started 2019-09-25-10-43-09 | |||
Evaluate data in 0.62 seconds! | |||
Evaluation on dev at Epoch 1/10. Step:43/430: | |||
SpanFPreRecMetric: f=0.070352, pre=0.100962, rec=0.053985 | |||
... | |||
Evaluate data in 0.61 seconds! | |||
Evaluation on dev at Epoch 10/10. Step:430/430: | |||
SpanFPreRecMetric: f=0.51223, pre=0.581699, rec=0.457584 | |||
In Epoch:7/Step:301, got best dev performance: | |||
SpanFPreRecMetric: f=0.515528, pre=0.65098, rec=0.426735 | |||
Reloaded the best model. | |||
训练结束之后过,可以通过 :class:`fastNLP.Tester`测试其在测试集上的性能 | |||
.. code-block::python | |||
from fastNLP import Tester | |||
tester = Tester(data_bundle.get_dataset('test'), model, metrics=metric) | |||
tester.test() |
@@ -495,6 +495,7 @@ class Trainer(object): | |||
else: | |||
raise TypeError("train_data type {} not support".format(type(train_data))) | |||
model.train() | |||
self.model = _move_model_to_device(model, device=device) | |||
if _model_contains_inner_module(self.model): | |||
self._forward_func = self.model.module.forward | |||
@@ -12,6 +12,7 @@ __all__ = [ | |||
"SeqLabeling", | |||
"AdvSeqLabel", | |||
"BiLSTMCRF", | |||
"ESIM", | |||
@@ -35,7 +36,7 @@ from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequen | |||
BertForTokenClassification, BertForSentenceMatching | |||
from .biaffine_parser import BiaffineParser, GraphParser | |||
from .cnn_text_classification import CNNText | |||
from .sequence_labeling import SeqLabeling, AdvSeqLabel | |||
from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF | |||
from .snli import ESIM | |||
from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel | |||
@@ -27,7 +27,7 @@ class BiLSTMCRF(BaseModel): | |||
""" | |||
def __init__(self, embed, num_classes, num_layers=1, hidden_size=100, dropout=0.5, | |||
target_vocab=None, encoding_type=None): | |||
target_vocab=None): | |||
""" | |||
:param embed: 支持(1)fastNLP的各种Embedding, (2) tuple, 指明num_embedding, dimension, 如(1000, 100) | |||
@@ -35,8 +35,7 @@ class BiLSTMCRF(BaseModel): | |||
:param num_layers: BiLSTM的层数 | |||
:param hidden_size: BiLSTM的hidden_size,实际hidden size为该值的两倍(前向、后向) | |||
:param dropout: dropout的概率,0为不dropout | |||
:param target_vocab: Vocabulary对象,target与index的对应关系 | |||
:param encoding_type: encoding的类型,支持'bioes', 'bmes', 'bio', 'bmeso'等 | |||
:param target_vocab: Vocabulary对象,target与index的对应关系。如果传入该值,将自动避免非法的解码序列。 | |||
""" | |||
super().__init__() | |||
self.embed = get_embeddings(embed) | |||
@@ -52,8 +51,9 @@ class BiLSTMCRF(BaseModel): | |||
self.fc = nn.Linear(hidden_size*2, num_classes) | |||
trans = None | |||
if target_vocab is not None and encoding_type is not None: | |||
trans = allowed_transitions(target_vocab.idx2word, encoding_type=encoding_type, include_start_end=True) | |||
if target_vocab is not None: | |||
assert len(target_vocab)==num_classes, "The number of classes should be same with the length of target vocabulary." | |||
trans = allowed_transitions(target_vocab.idx2word, include_start_end=True) | |||
self.crf = ConditionalRandomField(num_classes, include_start_end_trans=True, allowed_transitions=trans) | |||