@@ -92,7 +92,7 @@ http://docutils.sf.net/ Bare URLs are turned into links automatically
 Various links
 =============
-:doc:`/user/with_fitlog.rst`
+:doc:`/user/with_fitlog`
 :mod:`~fastNLP.core.batch`
@@ -438,7 +438,7 @@ def _bio_tag_to_spans(tags, ignore_labels=None):
 class SpanFPreRecMetric(MetricBase):
-    """
+    r"""
     Aliases: :class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric`
     Computes F, pre and rec at the span level for sequence labeling tasks.
@@ -476,7 +476,7 @@ class SpanFPreRecMetric(MetricBase):
         the f1, pre, rec of each label
     :param str f_type: 'micro' or 'macro'. 'micro': first accumulate the overall TP, FN and FP counts and compute f, precision and recall from them; 'macro':
         compute f, precision and recall for each class separately, then average them (each class's f is weighted equally)
-    :param float beta: the f_beta score, :math:`f_beta = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` .
+    :param float beta: the f_beta score, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` .
         Commonly beta=0.5, 1 or 2. With 0.5 precision is weighted more than recall; with 1 they are weighted equally; with 2 recall is weighted more than precision.
     """
@@ -699,16 +699,16 @@ def _pred_topk(y_prob, k=1):
 class SQuADMetric(MetricBase):
-    """
+    r"""
     Aliases: :class:`fastNLP.SQuADMetric` :class:`fastNLP.core.metrics.SQuADMetric`
     Metric for the SQuAD data set
-    :param pred1: the mapping for `pred1` in the parameter map; None means the mapping is `pred1`->`pred1`
-    :param pred2: the mapping for `pred2` in the parameter map; None means the mapping is `pred2`->`pred2`
-    :param target1: the mapping for `target1` in the parameter map; None means the mapping is `target1`->`target1`
-    :param target2: the mapping for `target2` in the parameter map; None means the mapping is `target2`->`target2`
-    :param float beta: the f_beta score, :math:`f_beta = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` .
+    :param pred1: the mapping for `pred1` in the parameter map; None means the mapping is `pred1` -> `pred1`
+    :param pred2: the mapping for `pred2` in the parameter map; None means the mapping is `pred2` -> `pred2`
+    :param target1: the mapping for `target1` in the parameter map; None means the mapping is `target1` -> `target1`
+    :param target2: the mapping for `target2` in the parameter map; None means the mapping is `target2` -> `target2`
+    :param float beta: the f_beta score, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` .
         Commonly beta=0.5, 1 or 2. With 0.5 precision is weighted more than recall; with 1 they are weighted equally; with 2 recall is weighted more than precision.
     :param bool right_open: if True, the start and end pointers denote a half-open interval [start, end); if False, a closed interval [start, end].
     :param bool print_predict_stat: if True, print statistics on whether the predicted answer and the gold answer are empty; if False, do not print them.
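To make the parameter map above concrete, a hedged usage sketch (not part of the diff; the field names 'pred_start', 'pred_end', 'start_idx', 'end_idx' are hypothetical) that renames the metric's expected arguments to the fields a model output and data set actually use:

from fastNLP.core.metrics import SQuADMetric

# Hypothetical field names; each keyword maps the metric's expected argument
# to the corresponding key in the model's output dict / data set.
metric = SQuADMetric(pred1='pred_start', pred2='pred_end',
                     target1='start_idx', target2='end_idx',
                     beta=1.0, right_open=True)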
@@ -1,13 +1,41 @@
 import torch
+import torch.nn as nn
+from fastNLP.core.const import Const
 from fastNLP.models import BaseModel
+from fastNLP.modules.encoder.bert import BertModel
-class BertForSNLI(BaseModel):
+class BertForNLI(BaseModel):
     # TODO: still in progress
-    def __init(self):
-        super(BertForSNLI, self).__init__()
+    def __init__(self, class_num=3, bert_dir=None):
+        super(BertForNLI, self).__init__()
+        if bert_dir is not None:
+            self.bert = BertModel.from_pretrained(bert_dir)
+        else:
+            self.bert = BertModel()
+        hidden_size = self.bert.pooler.dense._parameters['bias'].size(-1)
+        self.classifier = nn.Linear(hidden_size, class_num)
+
+    def forward(self, words, seq_len1, seq_len2, target=None):
+        """
+        :param torch.Tensor words: [batch_size, seq_len] input_ids
+        :param torch.Tensor seq_len1: [batch_size, seq_len] token_type_ids
+        :param torch.Tensor seq_len2: [batch_size, seq_len] attention_mask
+        :param torch.Tensor target: [batch]
+        :return:
+        """
+        _, pooled_output = self.bert(words, seq_len1, seq_len2)
+        logits = self.classifier(pooled_output)
+
+        if target is not None:
+            loss_func = torch.nn.CrossEntropyLoss()
+            loss = loss_func(logits, target)
+            return {Const.OUTPUT: logits, Const.LOSS: loss}
+        return {Const.OUTPUT: logits}
+
+    def predict(self, words, seq_len1, seq_len2, target=None):
+        return self.forward(words, seq_len1, seq_len2)
-    def forward(self, words, segment_id, seq_len):
-        pass
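A minimal smoke-test sketch for the model added above (not part of the diff). It assumes a downloaded BERT checkpoint directory behind the placeholder path, and the tensor shapes follow the forward docstring; the forward call returns a dict keyed by Const.OUTPUT (and Const.LOSS when a target is given).

import torch
from reproduction.matching.model.bert import BertForNLI

# 'path/to/bert/dir' is a placeholder for a downloaded BERT checkpoint directory.
model = BertForNLI(class_num=3, bert_dir='path/to/bert/dir')

words = torch.randint(0, 100, (2, 16))           # [batch_size, seq_len] input_ids
seq_len1 = torch.zeros(2, 16, dtype=torch.long)  # token_type_ids
seq_len2 = torch.ones(2, 16, dtype=torch.long)   # attention_mask
target = torch.tensor([0, 2])

out = model(words, seq_len1, seq_len2, target)   # dict with Const.OUTPUT and Const.LOSS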
@@ -0,0 +1,97 @@
+import os
+
+import torch
+
+from fastNLP.core import Vocabulary, DataSet, Trainer, Tester, Const, Adam, AccuracyMetric
+from reproduction.matching.data.SNLIDataLoader import SNLILoader
+from legacy.component.bert_tokenizer import BertTokenizer
+from reproduction.matching.model.bert import BertForNLI
+
+def preprocess_data(data: DataSet, bert_dir):
+    """
+    Preprocess the raw data set into the input format BERT needs.
+    :param data:
+    :param bert_dir:
+    :return:
+    """
+    tokenizer = BertTokenizer.from_pretrained(os.path.join(bert_dir, 'vocab.txt'))
+
+    vocab = Vocabulary(padding=None, unknown=None)
+    with open(os.path.join(bert_dir, 'vocab.txt')) as f:
+        lines = f.readlines()
+    vocab_list = []
+    for line in lines:
+        vocab_list.append(line.strip())
+    vocab.add_word_lst(vocab_list)
+    vocab.build_vocab()
+    vocab.padding = '[PAD]'
+    vocab.unknown = '[UNK]'
+
+    for i in range(2):
+        data.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])),
+                   new_field_name=Const.INPUTS(i))
+    data.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'],
+               new_field_name=Const.INPUT)
+    data.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
+               new_field_name=Const.INPUT_LENS(0))
+    data.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1))
+
+    max_len = 512
+    data.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT)
+    data.apply(lambda x: [vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT)
+    data.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0))
+    data.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1))
+
+    target_vocab = Vocabulary(padding=None, unknown=None)
+    target_vocab.add_word_lst(['neutral', 'contradiction', 'entailment'])
+    target_vocab.build_vocab()
+    data.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET)
+
+    data.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET)
+    data.set_target(Const.TARGET)
+
+    return data
+
+bert_dirs = 'path/to/bert/dir'
+
+# load raw data set
+train_data = SNLILoader().load('./data/snli/snli_1.0_train.jsonl')
+dev_data = SNLILoader().load('./data/snli/snli_1.0_dev.jsonl')
+test_data = SNLILoader().load('./data/snli/snli_1.0_test.jsonl')
+print('successfully load data sets!')
+
+train_data = preprocess_data(train_data, bert_dirs)
+dev_data = preprocess_data(dev_data, bert_dirs)
+test_data = preprocess_data(test_data, bert_dirs)
+
+model = BertForNLI(bert_dir=bert_dirs)
+
+trainer = Trainer(
+    train_data=train_data,
+    model=model,
+    optimizer=Adam(lr=2e-5, model_params=model.parameters()),
+    batch_size=torch.cuda.device_count() * 12,
+    n_epochs=4,
+    print_every=-1,
+    dev_data=dev_data,
+    metrics=AccuracyMetric(),
+    metric_key='acc',
+    device=[i for i in range(torch.cuda.device_count())],
+    check_code_level=-1
+)
+trainer.train(load_best_model=True)
+
+tester = Tester(
+    data=test_data,
+    model=model,
+    metrics=AccuracyMetric(),
+    batch_size=torch.cuda.device_count() * 12,
+    device=[i for i in range(torch.cuda.device_count())],
+)
+tester.test()
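For reference, a small sketch (not part of the diff) of the fields preprocess_data builds for one premise/hypothesis pair; the tokens are made up, and real WordPiece tokenization may split words further.

# Hypothetical example of the fields built above for one SNLI pair.
premise = ['a', 'man', 'is', 'eating']          # Const.INPUTS(0) after tokenization
hypothesis = ['a', 'person', 'eats']            # Const.INPUTS(1) after tokenization

words = ['[CLS]'] + premise + ['[SEP]'] + hypothesis + ['[SEP]']
token_type_ids = [0] * (len(premise) + 2) + [1] * (len(hypothesis) + 1)
attention_mask = [1] * len(token_type_ids)

assert len(words) == len(token_type_ids) == len(attention_mask) == 10
# After indexing against the BERT vocabulary, these become Const.INPUT,
# Const.INPUT_LENS(0) and Const.INPUT_LENS(1) respectively.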