|
|
@@ -0,0 +1,82 @@ |
|
|
|
============================== |
|
|
|
BertEmbedding的各种用法 |
|
|
|
============================== |
|
|
|
|
|
|
|
fastNLP的BertEmbedding以pytorch-transformers.BertModel的代码为基础,是一个使用BERT对words进行编码的Embedding。 |
|
|
|
|
|
|
|
使用BertEmbedding和fastNLP.models.bert里面的模型,可以搭建将BERT应用到五种下游任务的模型。 |
|
|
|
|
|
|
|
预训练好的Embedding参数及数据集的介绍和自动下载功能见 :doc:`/tutorials/tutorial_3_embedding` 和 |
|
|
|
:doc:`/tutorials/tutorial_4_load_dataset` 。 |
|
|
|
|
|
|
|
1. BERT for Sequence Classification |
|
|
|
----------------------------------- |
|
|
|
|
|
|
|
在文本分类任务中,我们采用SST数据集作为例子来介绍BertEmbedding的使用方法。 |
|
|
|
|
|
|
|
.. code-block:: python |
|
|
|
|
|
|
|
import warnings |
|
|
|
import torch |
|
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
|
|
# 载入数据集 |
|
|
|
from fastNLP.io import SSTPipe |
|
|
|
data_bundle = SSTPipe(subtree=False, train_subtree=False, lower=False, tokenizer='raw').process_from_file() |
|
|
|
data_bundle |
|
|
|
|
|
|
|
# 载入BertEmbedding |
|
|
|
from fastNLP.embeddings import BertEmbedding |
|
|
|
embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='en-base-cased', include_cls_sep=True) |
|
|
|
|
|
|
|
# 载入模型 |
|
|
|
from fastNLP.models import BertForSequenceClassification |
|
|
|
model = BertForSequenceClassification(embed, len(data_bundle.get_vocab('target'))) |
|
|
|
|
|
|
|
# 训练模型 |
|
|
|
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam |
|
|
|
trainer = Trainer(data_bundle.get_dataset('train'), model, |
|
|
|
optimizer=Adam(model_params=model.parameters(), lr=2e-5), |
|
|
|
loss=CrossEntropyLoss(), device=[0], |
|
|
|
batch_size=64, dev_data=data_bundle.get_dataset('dev'), |
|
|
|
metrics=AccuracyMetric(), n_epochs=2, print_every=1) |
|
|
|
trainer.train() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 在测试集上测试结果 |
|
|
|
from fastNLP import Tester |
|
|
|
tester = Tester(data_bundle.get_dataset('test'), model, batch_size=128, metrics=AccuracyMetric()) |
|
|
|
tester.test() |
|
|
|
|
|
|
|
2. BERT for Sentence Matching |
|
|
|
----------------------------- |
|
|
|
|
|
|
|
在Matching任务中,我们采用RTE数据集作为例子来介绍BertEmbedding的使用方法。 |
|
|
|
|
|
|
|
.. code-block:: python |
|
|
|
|
|
|
|
# 载入数据集 |
|
|
|
from fastNLP.io import RTEBertPipe |
|
|
|
data_bundle = RTEBertPipe(lower=False, tokenizer='raw').process_from_file() |
|
|
|
|
|
|
|
# 载入BertEmbedding |
|
|
|
from fastNLP.embeddings import BertEmbedding |
|
|
|
embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='en-base-cased', include_cls_sep=True) |
|
|
|
|
|
|
|
|
|
|
|
# 载入模型 |
|
|
|
from fastNLP.models import BertForSentenceMatching |
|
|
|
model = BertForSentenceMatching(embed, len(data_bundle.get_vocab('target'))) |
|
|
|
|
|
|
|
# 训练模型 |
|
|
|
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam |
|
|
|
trainer = Trainer(data_bundle.get_dataset('train'), model, |
|
|
|
optimizer=Adam(model_params=model.parameters(), lr=2e-5), |
|
|
|
loss=CrossEntropyLoss(), device=[0], |
|
|
|
batch_size=16, dev_data=data_bundle.get_dataset('dev'), |
|
|
|
metrics=AccuracyMetric(), n_epochs=2, print_every=1) |
|
|
|
trainer.train() |
|
|
|
|
|
|
|
|
|
|
|
|