@@ -269,7 +269,7 @@ class MatchingLoader(DataSetLoader):
     def _load(self, path: str) -> DataSet:
         raise NotImplementedError

-    def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
+    def process(self, paths: Union[str, Dict[str, str]], input_field=None) -> DataInfo:
         if isinstance(paths, str):
             paths = {'train': paths}
@@ -289,6 +289,13 @@ class MatchingLoader(DataSetLoader):
                raise RuntimeError(f'Your model is {self.data_format}, '
                                   f'Please choose from [esim, bert]')

+            if input_field is not None:
+                if isinstance(input_field, str):
+                    data.set_input(input_field)
+                elif isinstance(input_field, list):
+                    for field in input_field:
+                        data.set_input(field)
+
             data_set[n] = data
             print(f'successfully load {n} set!')
@@ -298,11 +305,11 @@ class MatchingLoader(DataSetLoader):
            raise RuntimeError(f'There is NOT label vocab attribute built!')

        if self.for_model != 'bert':
-            from fastNLP.modules.encoder.embedding import StaticEmbedding
-            embedding = StaticEmbedding(self.vocab, model_dir_or_name='en')
+            from fastNLP.modules.encoder.embedding import ElmoEmbedding
+            embedding = ElmoEmbedding(self.vocab, model_dir_or_name='en', requires_grad=True, layers='2')

        data_info = DataInfo(vocabs={'vocab': self.vocab, 'target_vocab': self.label_vocab},
-                             embeddings={'glove': embedding} if self.for_model != 'bert' else None,
+                             embeddings={'elmo': embedding} if self.for_model != 'bert' else None,
                             datasets=data_set)

        return data_info
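Note: a minimal sketch of what the returned DataInfo exposes after this change, for the non-bert path. All names are taken from the hunk above; the only assumption is that a 'train' split was passed to process():

    data_info = MatchingLoader(data_format='snli', for_model='esim').process(paths, input_field=[Const.TARGET])
    word_vocab   = data_info.vocabs['vocab']           # Vocabulary built from both input fields
    label_vocab  = data_info.vocabs['target_vocab']    # label Vocabulary (no padding/unknown)
    elmo_embed   = data_info.embeddings['elmo']        # ElmoEmbedding built over word_vocab
    train_set    = data_info.datasets['train']         # processed DataSet with inputs/targets already set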
@@ -338,15 +345,17 @@ class MatchingLoader(DataSetLoader):
        raw_ds.drop(lambda x: x[Const.TARGET] == '-')

        if not hasattr(self, 'vocab'):
-            self.vocab = Vocabulary().from_dataset(raw_ds, [Const.INPUTS(0), Const.INPUTS(1)])
+            self.vocab = Vocabulary().from_dataset(raw_ds, field_name=[Const.INPUTS(0), Const.INPUTS(1)])
        if not hasattr(self, 'label_vocab'):
            self.label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(raw_ds, field_name=Const.TARGET)

        raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0))
        raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1))
-        raw_ds.apply(lambda ins: self.label_vocab.to_index(Const.TARGET), new_field_name=Const.TARGET)
+        raw_ds.apply(lambda ins: self.label_vocab.to_index(ins[Const.TARGET]), new_field_name=Const.TARGET)
+        raw_ds.apply(lambda ins: len(ins[Const.INPUTS(0)]), new_field_name=Const.INPUT_LENS(0))
+        raw_ds.apply(lambda ins: len(ins[Const.INPUTS(1)]), new_field_name=Const.INPUT_LENS(1))

-        raw_ds.set_input(Const.INPUTS(0), Const.INPUTS(1))
+        raw_ds.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1))
        raw_ds.set_target(Const.TARGET)

        return raw_ds
@@ -405,6 +414,8 @@ class MatchingLoader(DataSetLoader):
        raw_ds.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1))
        raw_ds.set_target(Const.TARGET)

+        return raw_ds
+

 class SNLILoader(JsonLoader):
     """
@@ -2,9 +2,9 @@
 import os

 from torch import nn
 import torch
-from ...core import Vocabulary
+from ...core.vocabulary import Vocabulary
 from ...io.file_utils import _get_base_url, cached_path
-from ._bert import _WordPieceBertModel
+from ._bert import _WordPieceBertModel, BertModel


 class BertWordPieceEncoder(nn.Module):
@@ -152,6 +152,8 @@ class StaticEmbedding(TokenEmbedding):
     Example::

+        >>> embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50')
+
     :param vocab: Vocabulary. If this is None, every embedding in the file is loaded.
     :param model_dir_or_name: a pretrained static embedding can be specified in two ways: either pass the embedding file name, or pass the embedding
@@ -311,8 +313,7 @@ class ElmoEmbedding(ContextualEmbedding):
     Example::

-        >>>
-        >>>
+        >>> embedding = ElmoEmbedding(vocab, model_dir_or_name='en', layers='2', requires_grad=True)

     :param vocab: the vocabulary
     :param model_dir_or_name: a pretrained ELMo embedding can be specified in two ways: either pass the file name of the ELMo weights, or pass the name of an ELMo version,
@@ -403,7 +404,7 @@ class BertEmbedding(ContextualEmbedding):
     Example::

-        >>>
+        >>> embedding = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', requires_grad=False, layers='4,-2,-1')

     :param fastNLP.Vocabulary vocab: the vocabulary
@@ -513,7 +514,7 @@ class CNNCharEmbedding(TokenEmbedding):
     Example::

-        >>>
+        >>> cnn_char_embed = CNNCharEmbedding(vocab)

     :param vocab: the vocabulary
@@ -647,7 +648,7 @@ class LSTMCharEmbedding(TokenEmbedding):
     Example::

-        >>>
+        >>> lstm_char_embed = LSTMCharEmbedding(vocab)

     :param vocab: the vocabulary
     :param embed_size: size of the embedding. Default: 50.
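Note: the doctest lines filled in above only show construction. A minimal end-to-end sketch of how these embedding modules are typically used; this is not part of the diff, the toy word list is made up, and the add_word_lst/to_index/index-tensor call pattern follows fastNLP's embedding docs as I understand them:

    import torch
    from fastNLP.core.vocabulary import Vocabulary
    from fastNLP.modules.encoder.embedding import CNNCharEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst("this is a test .".split())     # toy vocabulary for illustration
    embed = CNNCharEmbedding(vocab)                     # char-CNN embedding, no pretrained file needed

    words = torch.LongTensor([[vocab.to_index(w) for w in "this is a test .".split()]])
    out = embed(words)                                  # expected shape: (1, 5, embed.embed_size)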
@@ -2,31 +2,31 @@ import os

 import torch

-from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric
+from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const
 from fastNLP.io.dataset_loader import MatchingLoader

 from reproduction.matching.model.bert import BertForNLI
+from reproduction.matching.model.esim import ESIMModel

-# bert_dirs = 'path/to/bert/dir'
-bert_dirs = '/remote-home/ygxu/BERT/BERT_English_uncased_L-12_H-768_A_12'
+bert_dirs = 'path/to/bert/dir'

 # load data set
-data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process(
-    {#'train': './data/snli/snli_1.0_train.jsonl',
+# data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process(...
+data_info = MatchingLoader(data_format='snli', for_model='esim').process(
+    {'train': './data/snli/snli_1.0_train.jsonl',
     'dev': './data/snli/snli_1.0_dev.jsonl',
-    'test': './data/snli/snli_1.0_test.jsonl'}
+    'test': './data/snli/snli_1.0_test.jsonl'},
+    input_field=[Const.TARGET]
 )
+print('successfully load data sets!')

-model = BertForNLI(bert_dir=bert_dirs)
+# model = BertForNLI(bert_dir=bert_dirs)
+model = ESIMModel(data_info.embeddings['elmo'],)

-trainer = Trainer(train_data=data_info.datasets['dev'], model=model,
-                  optimizer=Adam(lr=2e-5, model_params=model.parameters()),
-                  batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1,
+trainer = Trainer(train_data=data_info.datasets['train'], model=model,
+                  optimizer=Adam(lr=1e-4, model_params=model.parameters()),
+                  batch_size=torch.cuda.device_count() * 24, n_epochs=20, print_every=-1,
                   dev_data=data_info.datasets['dev'],
                   metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())],
                   check_code_level=-1)
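Note: Tester is imported above but not used in this hunk. A minimal evaluation sketch under the same names, not part of the diff; the keyword names follow fastNLP's Tester as I understand it, and the batch size is arbitrary:

    tester = Tester(data=data_info.datasets['test'], model=model,
                    metrics=AccuracyMetric(), batch_size=32)
    tester.test()   # prints/returns accuracy on the SNLI test split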
@@ -0,0 +1,182 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import CrossEntropyLoss
+
+from fastNLP.models import BaseModel
+from fastNLP.modules.encoder.embedding import TokenEmbedding
+from fastNLP.modules.encoder.lstm import LSTM
+from fastNLP.core.const import Const
+from fastNLP.core.utils import seq_len_to_mask
+
+
+class ESIMModel(BaseModel):
+    def __init__(self, init_embedding: TokenEmbedding, hidden_size=None, num_labels=3, dropout_rate=0.3,
+                 dropout_embed=0.1):
+        super(ESIMModel, self).__init__()
+
+        self.embedding = init_embedding
+        self.dropout_embed = EmbedDropout(p=dropout_embed)
+        if hidden_size is None:
+            hidden_size = self.embedding.embed_size
+        self.rnn = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate)
+        # self.rnn = LSTM(self.embedding.embed_size, hidden_size, dropout=dropout_rate, bidirectional=True)
+
+        self.interfere = nn.Sequential(nn.Dropout(p=dropout_rate),
+                                       nn.Linear(8 * hidden_size, hidden_size),
+                                       nn.ReLU())
+        nn.init.xavier_uniform_(self.interfere[1].weight.data)
+        self.bi_attention = SoftmaxAttention()
+
+        self.rnn_high = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate)
+        # self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True)
+
+        self.classifier = nn.Sequential(nn.Dropout(p=dropout_rate),
+                                        nn.Linear(8 * hidden_size, hidden_size),
+                                        nn.Tanh(),
+                                        nn.Dropout(p=dropout_rate),
+                                        nn.Linear(hidden_size, num_labels))
+
+        nn.init.xavier_uniform_(self.classifier[1].weight.data)
+        nn.init.xavier_uniform_(self.classifier[4].weight.data)
+    def forward(self, words1, words2, seq_len1, seq_len2, target=None):
+        mask1 = seq_len_to_mask(seq_len1)
+        mask2 = seq_len_to_mask(seq_len2)
+        a0 = self.embedding(words1)  # B * len * emb_dim
+        b0 = self.embedding(words2)
+        a0, b0 = self.dropout_embed(a0), self.dropout_embed(b0)
+        a = self.rnn(a0, mask1.byte())  # a: [B, PL, 2 * H]
+        b = self.rnn(b0, mask2.byte())
+
+        ai, bi = self.bi_attention(a, mask1, b, mask2)
+        a_ = torch.cat((a, ai, a - ai, a * ai), dim=2)  # ma: [B, PL, 8 * H]
+        b_ = torch.cat((b, bi, b - bi, b * bi), dim=2)
+        a_f = self.interfere(a_)
+        b_f = self.interfere(b_)
+
+        a_h = self.rnn_high(a_f, mask1.byte())  # ma: [B, PL, 2 * H]
+        b_h = self.rnn_high(b_f, mask2.byte())
+
+        a_avg = self.mean_pooling(a_h, mask1, dim=1)
+        a_max, _ = self.max_pooling(a_h, mask1, dim=1)
+        b_avg = self.mean_pooling(b_h, mask2, dim=1)
+        b_max, _ = self.max_pooling(b_h, mask2, dim=1)
+
+        out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1)  # v: [B, 8 * H]
+        logits = torch.tanh(self.classifier(out))
+
+        if target is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits, target)
+            return {Const.LOSS: loss, Const.OUTPUT: logits}
+        else:
+            return {Const.OUTPUT: logits}
+
+    def predict(self, **kwargs):
+        return self.forward(**kwargs)
+    # input [batch_size, len, hidden]
+    # mask  [batch_size, len] (1 1 1 ... 0 0)
+    @staticmethod
+    def mean_pooling(input, mask, dim=1):
+        masks = mask.view(mask.size(0), mask.size(1), -1).float()
+        return torch.sum(input * masks, dim=dim) / torch.sum(masks, dim=1)
+
+    @staticmethod
+    def max_pooling(input, mask, dim=1):
+        my_inf = 10e12
+        masks = mask.view(mask.size(0), mask.size(1), -1)
+        masks = masks.expand(-1, -1, input.size(2)).float()
+        return torch.max(input + masks.le(0.5).float() * -my_inf, dim=dim)
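+
+# EmbedDropout drops whole embedding dimensions rather than individual values:
+# a single (batch, embed_dim) Bernoulli mask is sampled once per batch and
+# broadcast over the sequence length, so a zeroed dimension stays zero at
+# every time step.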
+class EmbedDropout(nn.Dropout):
+
+    def forward(self, sequences_batch):
+        ones = sequences_batch.data.new_ones(sequences_batch.shape[0], sequences_batch.shape[-1])
+        dropout_mask = nn.functional.dropout(ones, self.p, self.training, inplace=False)
+        return dropout_mask.unsqueeze(1) * sequences_batch
+
+
+class BiRNN(nn.Module):
+    def __init__(self, input_size, hidden_size, dropout_rate=0.3):
+        super(BiRNN, self).__init__()
+        self.dropout_rate = dropout_rate
+        self.rnn = nn.LSTM(input_size, hidden_size,
+                           num_layers=1,
+                           bidirectional=True,
+                           batch_first=True)
+
+    def forward(self, x, x_mask):
+        # Sort x
+        lengths = x_mask.data.eq(1).long().sum(1).squeeze()
+        _, idx_sort = torch.sort(lengths, dim=0, descending=True)
+        _, idx_unsort = torch.sort(idx_sort, dim=0)
+        lengths = list(lengths[idx_sort])
+
+        x = x.index_select(0, idx_sort)
+        # Pack it up
+        rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
+        # Apply dropout to input
+        if self.dropout_rate > 0:
+            dropout_input = F.dropout(rnn_input.data, p=self.dropout_rate, training=self.training)
+            rnn_input = nn.utils.rnn.PackedSequence(dropout_input, rnn_input.batch_sizes)
+        output = self.rnn(rnn_input)[0]
+        # Unpack everything
+        output = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)[0]
+        output = output.index_select(0, idx_unsort)
+
+        if output.size(1) != x_mask.size(1):
+            padding = torch.zeros(output.size(0),
+                                  x_mask.size(1) - output.size(1),
+                                  output.size(2)).type(output.data.type())
+            output = torch.cat([output, padding], 1)
+
+        return output
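+
+# masked_softmax: multiplies padded positions by zero, applies softmax over the
+# last dimension, zeroes the masked entries again and renormalizes; the 1e-13
+# term guards against division by zero for fully masked rows.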
+def masked_softmax(tensor, mask):
+    tensor_shape = tensor.size()
+    reshaped_tensor = tensor.view(-1, tensor_shape[-1])
+
+    # Reshape the mask so it matches the size of the input tensor.
+    while mask.dim() < tensor.dim():
+        mask = mask.unsqueeze(1)
+    mask = mask.expand_as(tensor).contiguous().float()
+    reshaped_mask = mask.view(-1, mask.size()[-1])
+
+    result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1)
+    result = result * reshaped_mask
+    # 1e-13 is added to avoid divisions by zero.
+    result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
+
+    return result.view(*tensor_shape)
+
+
+def weighted_sum(tensor, weights, mask):
+    w_sum = weights.bmm(tensor)
+
+    while mask.dim() < w_sum.dim():
+        mask = mask.unsqueeze(1)
+    mask = mask.transpose(-1, -2)
+    mask = mask.expand_as(w_sum).contiguous().float()
+
+    return w_sum * mask
+
+
+class SoftmaxAttention(nn.Module):
+
+    def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask):
+        similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1)
+                                              .contiguous())
+
+        prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask)
+        hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2)
+                                       .contiguous(),
+                                       premise_mask)
+
+        attended_premises = weighted_sum(hypothesis_batch,
+                                         prem_hyp_attn,
+                                         premise_mask)
+        attended_hypotheses = weighted_sum(premise_batch,
+                                           hyp_prem_attn,
+                                           hypothesis_mask)
+
+        return attended_premises, attended_hypotheses
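Note: a minimal smoke-test sketch for the new ESIMModel, not part of the diff. It reuses the StaticEmbedding construction shown in the embedding docstring hunk above (which downloads GloVe weights on first use); the vocabulary contents, sentences, and batch of size one are made up for illustration:

    import torch
    from fastNLP.core.vocabulary import Vocabulary
    from fastNLP.modules.encoder.embedding import StaticEmbedding
    from reproduction.matching.model.esim import ESIMModel

    vocab = Vocabulary()
    vocab.add_word_lst("a man is playing a guitar".split())
    embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50')

    model = ESIMModel(embed)    # hidden_size defaults to embed.embed_size
    words1 = torch.LongTensor([[vocab.to_index(w) for w in "a man is playing a guitar".split()]])
    words2 = torch.LongTensor([[vocab.to_index(w) for w in "a man is playing".split()]])
    seq_len1 = torch.LongTensor([6])
    seq_len2 = torch.LongTensor([4])

    out = model(words1, words2, seq_len1, seq_len2)    # {Const.OUTPUT: logits of shape (1, 3)}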