@@ -16,6 +16,7 @@ __all__ = [ | |||
'CSVLoader', | |||
'JsonLoader', | |||
'ConllLoader', | |||
'MatchingLoader', | |||
'SNLILoader', | |||
'SSTLoader', | |||
'PeopleDailyCorpusLoader', | |||
@@ -26,6 +27,6 @@ __all__ = [ | |||
] | |||
from .embed_loader import EmbedLoader | |||
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \ | |||
PeopleDailyCorpusLoader, Conll2003Loader | |||
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, MatchingLoader,\ | |||
SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader | |||
from .model_io import ModelLoader, ModelSaver |
@@ -16,19 +16,24 @@ __all__ = [ | |||
'CSVLoader', | |||
'JsonLoader', | |||
'ConllLoader', | |||
'MatchingLoader', | |||
'SNLILoader', | |||
'SSTLoader', | |||
'PeopleDailyCorpusLoader', | |||
'Conll2003Loader', | |||
] | |||
import os | |||
from nltk import Tree | |||
from typing import Union, Dict | |||
from ..core.vocabulary import Vocabulary | |||
from ..core.dataset import DataSet | |||
from ..core.instance import Instance | |||
from .file_reader import _read_csv, _read_json, _read_conll | |||
from .base_loader import DataSetLoader | |||
from .base_loader import DataSetLoader, DataInfo | |||
from .data_loader.sst import SSTLoader | |||
from ..core.const import Const | |||
from ..modules.encoder._bert import BertTokenizer | |||
class PeopleDailyCorpusLoader(DataSetLoader): | |||
@@ -245,6 +250,173 @@ class JsonLoader(DataSetLoader): | |||
return ds | |||
class MatchingLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` | |||
读取Matching数据集,根据数据集做预处理并返回DataInfo。 | |||
数据来源: | |||
SNLI: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | |||
""" | |||
def __init__(self, data_format: str='snli', for_model: str='esim', bert_dir=None): | |||
super(MatchingLoader, self).__init__() | |||
self.data_format = data_format.lower() | |||
self.for_model = for_model.lower() | |||
self.bert_dir = bert_dir | |||
def _load(self, path: str) -> DataSet: | |||
raise NotImplementedError | |||
def process(self, paths: Union[str, Dict[str, str]], input_field=None) -> DataInfo: | |||
if isinstance(paths, str): | |||
paths = {'train': paths} | |||
data_set = {} | |||
for n, p in paths.items(): | |||
if self.data_format == 'snli': | |||
data = self._load_snli(p) | |||
else: | |||
raise RuntimeError(f'Your data format is {self.data_format}, ' | |||
f'Please choose data format from [snli]') | |||
if self.for_model == 'esim': | |||
data = self._for_esim(data) | |||
elif self.for_model == 'bert': | |||
data = self._for_bert(data, self.bert_dir) | |||
else: | |||
raise RuntimeError(f'Your model is {self.data_format}, ' | |||
f'Please choose from [esim, bert]') | |||
if input_field is not None: | |||
if isinstance(input_field, str): | |||
data.set_input(input_field) | |||
elif isinstance(input_field, list): | |||
for field in input_field: | |||
data.set_input(field) | |||
data_set[n] = data | |||
print(f'successfully load {n} set!') | |||
if not hasattr(self, 'vocab'): | |||
raise RuntimeError(f'There is NOT vocab attribute built!') | |||
if not hasattr(self, 'label_vocab'): | |||
raise RuntimeError(f'There is NOT label vocab attribute built!') | |||
if self.for_model != 'bert': | |||
from fastNLP.modules.encoder.embedding import ElmoEmbedding | |||
embedding = ElmoEmbedding(self.vocab, model_dir_or_name='en', requires_grad=True, layers='2') | |||
data_info = DataInfo(vocabs={'vocab': self.vocab, 'target_vocab': self.label_vocab}, | |||
embeddings={'elmo': embedding} if self.for_model != 'bert' else None, | |||
datasets=data_set) | |||
return data_info | |||
@staticmethod | |||
def _load_snli(path: str) -> DataSet: | |||
""" | |||
读取SNLI数据集 | |||
数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | |||
:param str path: 数据集路径 | |||
:return: | |||
""" | |||
raw_ds = JsonLoader( | |||
fields={ | |||
'sentence1_parse': Const.INPUTS(0), | |||
'sentence2_parse': Const.INPUTS(1), | |||
'gold_label': Const.TARGET, | |||
} | |||
)._load(path) | |||
return raw_ds | |||
def _for_esim(self, raw_ds: DataSet): | |||
if self.data_format == 'snli' or self.data_format == 'mnli': | |||
def parse_tree(x): | |||
t = Tree.fromstring(x) | |||
return t.leaves() | |||
raw_ds.apply(lambda ins: parse_tree( | |||
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) | |||
raw_ds.apply(lambda ins: parse_tree( | |||
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) | |||
raw_ds.drop(lambda x: x[Const.TARGET] == '-') | |||
if not hasattr(self, 'vocab'): | |||
self.vocab = Vocabulary().from_dataset(raw_ds, field_name=[Const.INPUTS(0), Const.INPUTS(1)]) | |||
if not hasattr(self, 'label_vocab'): | |||
self.label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(raw_ds, field_name=Const.TARGET) | |||
raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0)) | |||
raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1)) | |||
raw_ds.apply(lambda ins: self.label_vocab.to_index(ins[Const.TARGET]), new_field_name=Const.TARGET) | |||
raw_ds.apply(lambda ins: len(ins[Const.INPUTS(0)]), new_field_name=Const.INPUT_LENS(0)) | |||
raw_ds.apply(lambda ins: len(ins[Const.INPUTS(1)]), new_field_name=Const.INPUT_LENS(1)) | |||
raw_ds.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1)) | |||
raw_ds.set_target(Const.TARGET) | |||
return raw_ds | |||
def _for_bert(self, raw_ds: DataSet, bert_dir: str): | |||
if self.data_format == 'snli' or self.data_format == 'mnli': | |||
def parse_tree(x): | |||
t = Tree.fromstring(x) | |||
return t.leaves() | |||
raw_ds.apply(lambda ins: parse_tree( | |||
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) | |||
raw_ds.apply(lambda ins: parse_tree( | |||
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) | |||
raw_ds.drop(lambda x: x[Const.TARGET] == '-') | |||
tokenizer = BertTokenizer.from_pretrained(bert_dir) | |||
vocab = Vocabulary(padding=None, unknown=None) | |||
with open(os.path.join(bert_dir, 'vocab.txt')) as f: | |||
lines = f.readlines() | |||
vocab_list = [] | |||
for line in lines: | |||
vocab_list.append(line.strip()) | |||
vocab.add_word_lst(vocab_list) | |||
vocab.build_vocab() | |||
vocab.padding = '[PAD]' | |||
vocab.unknown = '[UNK]' | |||
if not hasattr(self, 'vocab'): | |||
self.vocab = vocab | |||
else: | |||
for w, idx in self.vocab: | |||
if vocab[w] != idx: | |||
raise AttributeError(f"{self.__class__.__name__} has ") | |||
for i in range(2): | |||
raw_ds.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])), new_field_name=Const.INPUTS(i)) | |||
raw_ds.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'], | |||
new_field_name=Const.INPUT) | |||
raw_ds.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1), | |||
new_field_name=Const.INPUT_LENS(0)) | |||
raw_ds.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1)) | |||
max_len = 512 | |||
raw_ds.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT) | |||
raw_ds.apply(lambda x: [self.vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT) | |||
raw_ds.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0)) | |||
raw_ds.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1)) | |||
if not hasattr(self, 'label_vocab'): | |||
self.label_vocab = Vocabulary(padding=None, unknown=None) | |||
self.label_vocab.from_dataset(raw_ds, field_name=Const.TARGET) | |||
raw_ds.apply(lambda x: self.label_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET) | |||
raw_ds.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1)) | |||
raw_ds.set_target(Const.TARGET) | |||
return raw_ds | |||
class SNLILoader(JsonLoader): | |||
""" | |||
别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` | |||
@@ -7,6 +7,12 @@ __all__ = [ | |||
"ConvMaxpool", | |||
"Embedding", | |||
"StaticEmbedding", | |||
"ElmoEmbedding", | |||
"BertEmbedding", | |||
"StackEmbedding", | |||
"LSTMCharEmbedding", | |||
"CNNCharEmbedding", | |||
"LSTM", | |||
@@ -22,7 +28,8 @@ from ._bert import BertModel | |||
from .bert import BertWordPieceEncoder | |||
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder | |||
from .conv_maxpool import ConvMaxpool | |||
from .embedding import Embedding | |||
from .embedding import Embedding, StaticEmbedding, ElmoEmbedding, BertEmbedding, \ | |||
StackEmbedding, LSTMCharEmbedding, CNNCharEmbedding | |||
from .lstm import LSTM | |||
from .star_transformer import StarTransformer | |||
from .transformer import TransformerEncoder | |||
@@ -7,7 +7,7 @@ | |||
from ... import Vocabulary | |||
from ...core.vocabulary import Vocabulary | |||
import collections | |||
import unicodedata | |||
@@ -2,9 +2,9 @@ | |||
import os | |||
from torch import nn | |||
import torch | |||
from ...core import Vocabulary | |||
from ...core.vocabulary import Vocabulary | |||
from ...io.file_utils import _get_base_url, cached_path | |||
from ._bert import _WordPieceBertModel | |||
from ._bert import _WordPieceBertModel, BertModel | |||
class BertWordPieceEncoder(nn.Module): | |||
@@ -1,10 +1,16 @@ | |||
__all__ = [ | |||
"Embedding" | |||
"Embedding", | |||
"StaticEmbedding", | |||
"ElmoEmbedding", | |||
"BertEmbedding", | |||
"StackEmbedding", | |||
"LSTMCharEmbedding", | |||
"CNNCharEmbedding", | |||
] | |||
import torch.nn as nn | |||
from ..utils import get_embeddings | |||
from .lstm import LSTM | |||
from ... import Vocabulary | |||
from ...core.vocabulary import Vocabulary | |||
from abc import abstractmethod | |||
import torch | |||
from ...io import EmbedLoader | |||
@@ -15,7 +21,9 @@ from ...io.file_utils import cached_path, _get_base_url | |||
from ._bert import _WordBertModel | |||
from typing import List | |||
from ... import DataSet, DataSetIter, SequentialSampler | |||
from ...core.dataset import DataSet | |||
from ...core.batch import DataSetIter | |||
from ...core.sampler import SequentialSampler | |||
from ...core.utils import _move_model_to_device, _get_model_device | |||
@@ -144,6 +152,8 @@ class StaticEmbedding(TokenEmbedding): | |||
Example:: | |||
>>> embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50') | |||
:param vocab: Vocabulary. 若该项为None则会读取所有的embedding。 | |||
:param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding的文件名,第二种是传入embedding | |||
@@ -303,8 +313,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||
Example:: | |||
>>> | |||
>>> | |||
>>> embedding = ElmoEmbedding(vocab, model_dir_or_name='en', layers='2', requires_grad=True) | |||
:param vocab: 词表 | |||
:param model_dir_or_name: 可以有两种方式调用预训练好的ELMo embedding:第一种是传入ELMo权重的文件名,第二种是传入ELMo版本的名称, | |||
@@ -395,7 +404,7 @@ class BertEmbedding(ContextualEmbedding): | |||
Example:: | |||
>>> | |||
>>> embedding = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', requires_grad=False, layers='4,-2,-1') | |||
:param fastNLP.Vocabulary vocab: 词表 | |||
@@ -505,7 +514,7 @@ class CNNCharEmbedding(TokenEmbedding): | |||
Example:: | |||
>>> | |||
>>> cnn_char_embed = CNNCharEmbedding(vocab) | |||
:param vocab: 词表 | |||
@@ -641,7 +650,7 @@ class LSTMCharEmbedding(TokenEmbedding): | |||
Example:: | |||
>>> | |||
>>> lstm_char_embed = LSTMCharEmbedding(vocab) | |||
:param vocab: 词表 | |||
:param embed_size: embedding的大小。默认值为50. | |||
@@ -0,0 +1,44 @@ | |||
import os | |||
import torch | |||
from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const | |||
from fastNLP.io.dataset_loader import MatchingLoader | |||
from reproduction.matching.model.bert import BertForNLI | |||
from reproduction.matching.model.esim import ESIMModel | |||
bert_dirs = 'path/to/bert/dir' | |||
# load data set | |||
# data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process(... | |||
data_info = MatchingLoader(data_format='snli', for_model='esim').process( | |||
{'train': './data/snli/snli_1.0_train.jsonl', | |||
'dev': './data/snli/snli_1.0_dev.jsonl', | |||
'test': './data/snli/snli_1.0_test.jsonl'}, | |||
input_field=[Const.TARGET] | |||
) | |||
# model = BertForNLI(bert_dir=bert_dirs) | |||
model = ESIMModel(data_info.embeddings['elmo'],) | |||
trainer = Trainer(train_data=data_info.datasets['train'], model=model, | |||
optimizer=Adam(lr=1e-4, model_params=model.parameters()), | |||
batch_size=torch.cuda.device_count() * 24, n_epochs=20, print_every=-1, | |||
dev_data=data_info.datasets['dev'], | |||
metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], | |||
check_code_level=-1) | |||
trainer.train(load_best_model=True) | |||
tester = Tester( | |||
data=data_info.datasets['test'], | |||
model=model, | |||
metrics=AccuracyMetric(), | |||
batch_size=torch.cuda.device_count() * 12, | |||
device=[i for i in range(torch.cuda.device_count())], | |||
) | |||
tester.test() | |||
@@ -0,0 +1,182 @@ | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from torch.nn import CrossEntropyLoss | |||
from fastNLP.models import BaseModel | |||
from fastNLP.modules.encoder.embedding import TokenEmbedding | |||
from fastNLP.modules.encoder.lstm import LSTM | |||
from fastNLP.core.const import Const | |||
from fastNLP.core.utils import seq_len_to_mask | |||
class ESIMModel(BaseModel): | |||
def __init__(self, init_embedding: TokenEmbedding, hidden_size=None, num_labels=3, dropout_rate=0.3, | |||
dropout_embed=0.1): | |||
super(ESIMModel, self).__init__() | |||
self.embedding = init_embedding | |||
self.dropout_embed = EmbedDropout(p=dropout_embed) | |||
if hidden_size is None: | |||
hidden_size = self.embedding.embed_size | |||
self.rnn = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate) | |||
# self.rnn = LSTM(self.embedding.embed_size, hidden_size, dropout=dropout_rate, bidirectional=True) | |||
self.interfere = nn.Sequential(nn.Dropout(p=dropout_rate), | |||
nn.Linear(8 * hidden_size, hidden_size), | |||
nn.ReLU()) | |||
nn.init.xavier_uniform_(self.interfere[1].weight.data) | |||
self.bi_attention = SoftmaxAttention() | |||
self.rnn_high = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate) | |||
# self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True) | |||
self.classifier = nn.Sequential(nn.Dropout(p=dropout_rate), | |||
nn.Linear(8 * hidden_size, hidden_size), | |||
nn.Tanh(), | |||
nn.Dropout(p=dropout_rate), | |||
nn.Linear(hidden_size, num_labels)) | |||
nn.init.xavier_uniform_(self.classifier[1].weight.data) | |||
nn.init.xavier_uniform_(self.classifier[4].weight.data) | |||
def forward(self, words1, words2, seq_len1, seq_len2, target=None): | |||
mask1 = seq_len_to_mask(seq_len1) | |||
mask2 = seq_len_to_mask(seq_len2) | |||
a0 = self.embedding(words1) # B * len * emb_dim | |||
b0 = self.embedding(words2) | |||
a0, b0 = self.dropout_embed(a0), self.dropout_embed(b0) | |||
a = self.rnn(a0, mask1.byte()) # a: [B, PL, 2 * H] | |||
b = self.rnn(b0, mask2.byte()) | |||
ai, bi = self.bi_attention(a, mask1, b, mask2) | |||
a_ = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 8 * H] | |||
b_ = torch.cat((b, bi, b - bi, b * bi), dim=2) | |||
a_f = self.interfere(a_) | |||
b_f = self.interfere(b_) | |||
a_h = self.rnn_high(a_f, mask1.byte()) # ma: [B, PL, 2 * H] | |||
b_h = self.rnn_high(b_f, mask2.byte()) | |||
a_avg = self.mean_pooling(a_h, mask1, dim=1) | |||
a_max, _ = self.max_pooling(a_h, mask1, dim=1) | |||
b_avg = self.mean_pooling(b_h, mask2, dim=1) | |||
b_max, _ = self.max_pooling(b_h, mask2, dim=1) | |||
out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1) # v: [B, 8 * H] | |||
logits = torch.tanh(self.classifier(out)) | |||
if target is not None: | |||
loss_fct = CrossEntropyLoss() | |||
loss = loss_fct(logits, target) | |||
return {Const.LOSS: loss, Const.OUTPUT: logits} | |||
else: | |||
return {Const.OUTPUT: logits} | |||
def predict(self, **kwargs): | |||
return self.forward(**kwargs) | |||
# input [batch_size, len , hidden] | |||
# mask [batch_size, len] (111...00) | |||
@staticmethod | |||
def mean_pooling(input, mask, dim=1): | |||
masks = mask.view(mask.size(0), mask.size(1), -1).float() | |||
return torch.sum(input * masks, dim=dim) / torch.sum(masks, dim=1) | |||
@staticmethod | |||
def max_pooling(input, mask, dim=1): | |||
my_inf = 10e12 | |||
masks = mask.view(mask.size(0), mask.size(1), -1) | |||
masks = masks.expand(-1, -1, input.size(2)).float() | |||
return torch.max(input + masks.le(0.5).float() * -my_inf, dim=dim) | |||
class EmbedDropout(nn.Dropout): | |||
def forward(self, sequences_batch): | |||
ones = sequences_batch.data.new_ones(sequences_batch.shape[0], sequences_batch.shape[-1]) | |||
dropout_mask = nn.functional.dropout(ones, self.p, self.training, inplace=False) | |||
return dropout_mask.unsqueeze(1) * sequences_batch | |||
class BiRNN(nn.Module): | |||
def __init__(self, input_size, hidden_size, dropout_rate=0.3): | |||
super(BiRNN, self).__init__() | |||
self.dropout_rate = dropout_rate | |||
self.rnn = nn.LSTM(input_size, hidden_size, | |||
num_layers=1, | |||
bidirectional=True, | |||
batch_first=True) | |||
def forward(self, x, x_mask): | |||
# Sort x | |||
lengths = x_mask.data.eq(1).long().sum(1).squeeze() | |||
_, idx_sort = torch.sort(lengths, dim=0, descending=True) | |||
_, idx_unsort = torch.sort(idx_sort, dim=0) | |||
lengths = list(lengths[idx_sort]) | |||
x = x.index_select(0, idx_sort) | |||
# Pack it up | |||
rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True) | |||
# Apply dropout to input | |||
if self.dropout_rate > 0: | |||
dropout_input = F.dropout(rnn_input.data, p=self.dropout_rate, training=self.training) | |||
rnn_input = nn.utils.rnn.PackedSequence(dropout_input, rnn_input.batch_sizes) | |||
output = self.rnn(rnn_input)[0] | |||
# Unpack everything | |||
output = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)[0] | |||
output = output.index_select(0, idx_unsort) | |||
if output.size(1) != x_mask.size(1): | |||
padding = torch.zeros(output.size(0), | |||
x_mask.size(1) - output.size(1), | |||
output.size(2)).type(output.data.type()) | |||
output = torch.cat([output, padding], 1) | |||
return output | |||
def masked_softmax(tensor, mask): | |||
tensor_shape = tensor.size() | |||
reshaped_tensor = tensor.view(-1, tensor_shape[-1]) | |||
# Reshape the mask so it matches the size of the input tensor. | |||
while mask.dim() < tensor.dim(): | |||
mask = mask.unsqueeze(1) | |||
mask = mask.expand_as(tensor).contiguous().float() | |||
reshaped_mask = mask.view(-1, mask.size()[-1]) | |||
result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1) | |||
result = result * reshaped_mask | |||
# 1e-13 is added to avoid divisions by zero. | |||
result = result / (result.sum(dim=-1, keepdim=True) + 1e-13) | |||
return result.view(*tensor_shape) | |||
def weighted_sum(tensor, weights, mask): | |||
w_sum = weights.bmm(tensor) | |||
while mask.dim() < w_sum.dim(): | |||
mask = mask.unsqueeze(1) | |||
mask = mask.transpose(-1, -2) | |||
mask = mask.expand_as(w_sum).contiguous().float() | |||
return w_sum * mask | |||
class SoftmaxAttention(nn.Module): | |||
def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask): | |||
similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1) | |||
.contiguous()) | |||
prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask) | |||
hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2) | |||
.contiguous(), | |||
premise_mask) | |||
attended_premises = weighted_sum(hypothesis_batch, | |||
prem_hyp_attn, | |||
premise_mask) | |||
attended_hypotheses = weighted_sum(premise_batch, | |||
hyp_prem_attn, | |||
hypothesis_mask) | |||
return attended_premises, attended_hypotheses |
@@ -1,88 +0,0 @@ | |||
import os | |||
import torch | |||
from fastNLP.core import Vocabulary, DataSet, Trainer, Tester, Const, Adam, AccuracyMetric | |||
from reproduction.matching.data.SNLIDataLoader import SNLILoader | |||
from legacy.component.bert_tokenizer import BertTokenizer | |||
from reproduction.matching.model.bert import BertForNLI | |||
def preprocess_data(data: DataSet, bert_dir): | |||
""" | |||
preprocess data set to bert-need data set. | |||
:param data: | |||
:param bert_dir: | |||
:return: | |||
""" | |||
tokenizer = BertTokenizer.from_pretrained(os.path.join(bert_dir, 'vocab.txt')) | |||
vocab = Vocabulary(padding=None, unknown=None) | |||
with open(os.path.join(bert_dir, 'vocab.txt')) as f: | |||
lines = f.readlines() | |||
vocab_list = [] | |||
for line in lines: | |||
vocab_list.append(line.strip()) | |||
vocab.add_word_lst(vocab_list) | |||
vocab.build_vocab() | |||
vocab.padding = '[PAD]' | |||
vocab.unknown = '[UNK]' | |||
for i in range(2): | |||
data.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])), | |||
new_field_name=Const.INPUTS(i)) | |||
data.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'], | |||
new_field_name=Const.INPUT) | |||
data.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1), | |||
new_field_name=Const.INPUT_LENS(0)) | |||
data.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1)) | |||
max_len = 512 | |||
data.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT) | |||
data.apply(lambda x: [vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT) | |||
data.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0)) | |||
data.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1)) | |||
target_vocab = Vocabulary(padding=None, unknown=None) | |||
target_vocab.add_word_lst(['neutral', 'contradiction', 'entailment']) | |||
target_vocab.build_vocab() | |||
data.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET) | |||
data.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET) | |||
data.set_target(Const.TARGET) | |||
return data | |||
bert_dirs = 'path/to/bert/dir' | |||
# load raw data set | |||
train_data = SNLILoader().load('./data/snli/snli_1.0_train.jsonl') | |||
dev_data = SNLILoader().load('./data/snli/snli_1.0_dev.jsonl') | |||
test_data = SNLILoader().load('./data/snli/snli_1.0_test.jsonl') | |||
print('successfully load data sets!') | |||
train_data = preprocess_data(train_data, bert_dirs) | |||
dev_data = preprocess_data(dev_data, bert_dirs) | |||
test_data = preprocess_data(test_data, bert_dirs) | |||
model = BertForNLI(bert_dir=bert_dirs) | |||
trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(lr=2e-5, model_params=model.parameters()), | |||
batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1, dev_data=dev_data, | |||
metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], | |||
check_code_level=-1) | |||
trainer.train(load_best_model=True) | |||
tester = Tester( | |||
data=test_data, | |||
model=model, | |||
metrics=AccuracyMetric(), | |||
batch_size=torch.cuda.device_count() * 12, | |||
device=[i for i in range(torch.cuda.device_count())], | |||
) | |||
tester.test() | |||