
Merge branch 'dev0.5.0' of github.com:fastnlp/fastNLP into dev0.5.0

tags/v0.4.10
yh_cc 5 years ago
commit 4d138ed7f8
9 changed files with 430 additions and 103 deletions
  1. +3 -2    fastNLP/io/__init__.py
  2. +173 -1  fastNLP/io/dataset_loader.py
  3. +8 -1    fastNLP/modules/encoder/__init__.py
  4. +1 -1    fastNLP/modules/encoder/_bert.py
  5. +2 -2    fastNLP/modules/encoder/bert.py
  6. +17 -8   fastNLP/modules/encoder/embedding.py
  7. +44 -0   reproduction/matching/matching.py
  8. +182 -0  reproduction/matching/model/esim.py
  9. +0 -88   reproduction/matching/snli.py

+3 -2  fastNLP/io/__init__.py

@@ -16,6 +16,7 @@ __all__ = [
'CSVLoader',
'JsonLoader',
'ConllLoader',
'MatchingLoader',
'SNLILoader',
'SSTLoader',
'PeopleDailyCorpusLoader',
@@ -26,6 +27,6 @@ __all__ = [
]


from .embed_loader import EmbedLoader
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \
PeopleDailyCorpusLoader, Conll2003Loader
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, MatchingLoader,\
SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader
from .model_io import ModelLoader, ModelSaver

+173 -1  fastNLP/io/dataset_loader.py

@@ -16,19 +16,24 @@ __all__ = [
'CSVLoader',
'JsonLoader',
'ConllLoader',
'MatchingLoader',
'SNLILoader',
'SSTLoader',
'PeopleDailyCorpusLoader',
'Conll2003Loader',
]


import os
from nltk import Tree
from typing import Union, Dict
from ..core.vocabulary import Vocabulary
from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll
from .base_loader import DataSetLoader
from .base_loader import DataSetLoader, DataInfo
from .data_loader.sst import SSTLoader
from ..core.const import Const
from ..modules.encoder._bert import BertTokenizer




class PeopleDailyCorpusLoader(DataSetLoader):
@@ -245,6 +250,173 @@ class JsonLoader(DataSetLoader):
return ds




class MatchingLoader(DataSetLoader):
"""
Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`

Reads a matching dataset, preprocesses it for the chosen model, and returns a DataInfo.

Data source:
SNLI: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
"""

def __init__(self, data_format: str='snli', for_model: str='esim', bert_dir=None):
super(MatchingLoader, self).__init__()
self.data_format = data_format.lower()
self.for_model = for_model.lower()
self.bert_dir = bert_dir

def _load(self, path: str) -> DataSet:
raise NotImplementedError

def process(self, paths: Union[str, Dict[str, str]], input_field=None) -> DataInfo:
if isinstance(paths, str):
paths = {'train': paths}

data_set = {}
for n, p in paths.items():
if self.data_format == 'snli':
data = self._load_snli(p)
else:
raise RuntimeError(f'Your data format is {self.data_format}, '
f'please choose a data format from [snli]')

if self.for_model == 'esim':
data = self._for_esim(data)
elif self.for_model == 'bert':
data = self._for_bert(data, self.bert_dir)
else:
raise RuntimeError(f'Your model is {self.for_model}, '
f'please choose a model from [esim, bert]')

if input_field is not None:
if isinstance(input_field, str):
data.set_input(input_field)
elif isinstance(input_field, list):
for field in input_field:
data.set_input(field)

data_set[n] = data
print(f'successfully loaded the {n} set!')

if not hasattr(self, 'vocab'):
raise RuntimeError(f'No vocab attribute has been built!')
if not hasattr(self, 'label_vocab'):
raise RuntimeError(f'No label_vocab attribute has been built!')

if self.for_model != 'bert':
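# non-BERT models get an ELMo embedding built over the vocabulary assembled above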
from fastNLP.modules.encoder.embedding import ElmoEmbedding
embedding = ElmoEmbedding(self.vocab, model_dir_or_name='en', requires_grad=True, layers='2')

data_info = DataInfo(vocabs={'vocab': self.vocab, 'target_vocab': self.label_vocab},
embeddings={'elmo': embedding} if self.for_model != 'bert' else None,
datasets=data_set)

return data_info

@staticmethod
def _load_snli(path: str) -> DataSet:
"""
Read the SNLI dataset.

Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
:param str path: path to the data file
:return:
"""
raw_ds = JsonLoader(
fields={
'sentence1_parse': Const.INPUTS(0),
'sentence2_parse': Const.INPUTS(1),
'gold_label': Const.TARGET,
}
)._load(path)
return raw_ds

def _for_esim(self, raw_ds: DataSet):
if self.data_format == 'snli' or self.data_format == 'mnli':
def parse_tree(x):
t = Tree.fromstring(x)
return t.leaves()

raw_ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
raw_ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
raw_ds.drop(lambda x: x[Const.TARGET] == '-')

if not hasattr(self, 'vocab'):
self.vocab = Vocabulary().from_dataset(raw_ds, field_name=[Const.INPUTS(0), Const.INPUTS(1)])
if not hasattr(self, 'label_vocab'):
self.label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(raw_ds, field_name=Const.TARGET)

raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0))
raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1))
raw_ds.apply(lambda ins: self.label_vocab.to_index(ins[Const.TARGET]), new_field_name=Const.TARGET)
raw_ds.apply(lambda ins: len(ins[Const.INPUTS(0)]), new_field_name=Const.INPUT_LENS(0))
raw_ds.apply(lambda ins: len(ins[Const.INPUTS(1)]), new_field_name=Const.INPUT_LENS(1))

raw_ds.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1))
raw_ds.set_target(Const.TARGET)

return raw_ds

def _for_bert(self, raw_ds: DataSet, bert_dir: str):
if self.data_format == 'snli' or self.data_format == 'mnli':
def parse_tree(x):
t = Tree.fromstring(x)
return t.leaves()

raw_ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
raw_ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
raw_ds.drop(lambda x: x[Const.TARGET] == '-')

tokenizer = BertTokenizer.from_pretrained(bert_dir)

vocab = Vocabulary(padding=None, unknown=None)
with open(os.path.join(bert_dir, 'vocab.txt')) as f:
lines = f.readlines()
vocab_list = []
for line in lines:
vocab_list.append(line.strip())
vocab.add_word_lst(vocab_list)
vocab.build_vocab()
vocab.padding = '[PAD]'
vocab.unknown = '[UNK]'

if not hasattr(self, 'vocab'):
self.vocab = vocab
else:
for w, idx in self.vocab:
if vocab[w] != idx:
raise AttributeError(f"{self.__class__.__name__} has ")

for i in range(2):
raw_ds.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])), new_field_name=Const.INPUTS(i))
raw_ds.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'],
new_field_name=Const.INPUT)
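# token type ids, stored in INPUT_LENS(0): segment 0 covers [CLS] + sentence1 + [SEP], segment 1 covers sentence2 + [SEP]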
raw_ds.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
new_field_name=Const.INPUT_LENS(0))
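# attention mask, stored in INPUT_LENS(1): all ones over the concatenated word pieces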
raw_ds.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1))

max_len = 512
raw_ds.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT)
raw_ds.apply(lambda x: [self.vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT)
raw_ds.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0))
raw_ds.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1))

if not hasattr(self, 'label_vocab'):
self.label_vocab = Vocabulary(padding=None, unknown=None)
self.label_vocab.from_dataset(raw_ds, field_name=Const.TARGET)
raw_ds.apply(lambda x: self.label_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET)

raw_ds.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1))
raw_ds.set_target(Const.TARGET)

return raw_ds


class SNLILoader(JsonLoader):
"""
Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`


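reproduction/matching/matching.py below exercises the full pipeline; a minimal sketch of the loader call alone (the paths are placeholders, and the ESIM branch assumes the 'en' ELMo weights can be loaded):

from fastNLP.io.dataset_loader import MatchingLoader

# one call builds the vocabularies, the ELMo embedding and the indexed datasets
data_info = MatchingLoader(data_format='snli', for_model='esim').process(
    {'train': 'snli_1.0_train.jsonl', 'dev': 'snli_1.0_dev.jsonl'})
print(data_info.datasets['train'][0])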
+8 -1  fastNLP/modules/encoder/__init__.py

@@ -7,6 +7,12 @@ __all__ = [
"ConvMaxpool", "ConvMaxpool",
"Embedding", "Embedding",
"StaticEmbedding",
"ElmoEmbedding",
"BertEmbedding",
"StackEmbedding",
"LSTMCharEmbedding",
"CNNCharEmbedding",
"LSTM", "LSTM",
@@ -22,7 +28,8 @@ from ._bert import BertModel
from .bert import BertWordPieceEncoder
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
from .conv_maxpool import ConvMaxpool
from .embedding import Embedding
from .embedding import Embedding, StaticEmbedding, ElmoEmbedding, BertEmbedding, \
StackEmbedding, LSTMCharEmbedding, CNNCharEmbedding
from .lstm import LSTM
from .star_transformer import StarTransformer
from .transformer import TransformerEncoder


+1 -1  fastNLP/modules/encoder/_bert.py

@@ -7,7 +7,7 @@






from ... import Vocabulary
from ...core.vocabulary import Vocabulary
import collections


import unicodedata


+2 -2  fastNLP/modules/encoder/bert.py

@@ -2,9 +2,9 @@
import os
from torch import nn
import torch
from ...core import Vocabulary
from ...core.vocabulary import Vocabulary
from ...io.file_utils import _get_base_url, cached_path
from ._bert import _WordPieceBertModel
from ._bert import _WordPieceBertModel, BertModel




class BertWordPieceEncoder(nn.Module):


+17 -8  fastNLP/modules/encoder/embedding.py

@@ -1,10 +1,16 @@
__all__ = [
"Embedding"
"Embedding",
"StaticEmbedding",
"ElmoEmbedding",
"BertEmbedding",
"StackEmbedding",
"LSTMCharEmbedding",
"CNNCharEmbedding",
]
import torch.nn as nn
from ..utils import get_embeddings
from .lstm import LSTM
from ... import Vocabulary
from ...core.vocabulary import Vocabulary
from abc import abstractmethod
import torch
from ...io import EmbedLoader
@@ -15,7 +21,9 @@ from ...io.file_utils import cached_path, _get_base_url
from ._bert import _WordBertModel
from typing import List


from ... import DataSet, DataSetIter, SequentialSampler
from ...core.dataset import DataSet
from ...core.batch import DataSetIter
from ...core.sampler import SequentialSampler
from ...core.utils import _move_model_to_device, _get_model_device




@@ -144,6 +152,8 @@ class StaticEmbedding(TokenEmbedding):


Example::


>>> embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50')



:param vocab: Vocabulary. If this is None, all embeddings in the file are loaded.
:param model_dir_or_name: A pretrained static embedding can be used in two ways: the first is to pass the name of the embedding file, the second is to pass the embedding
@@ -303,8 +313,7 @@ class ElmoEmbedding(ContextualEmbedding):


Example::


>>>
>>>
>>> embedding = ElmoEmbedding(vocab, model_dir_or_name='en', layers='2', requires_grad=True)


:param vocab: vocabulary
:param model_dir_or_name: A pretrained ELMo embedding can be used in two ways: the first is to pass the name of the ELMo weight file, the second is to pass the name of the ELMo version,
@@ -395,7 +404,7 @@ class BertEmbedding(ContextualEmbedding):


Example::


>>>
>>> embedding = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', requires_grad=False, layers='4,-2,-1')




:param fastNLP.Vocabulary vocab: vocabulary
@@ -505,7 +514,7 @@ class CNNCharEmbedding(TokenEmbedding):


Example::


>>>
>>> cnn_char_embed = CNNCharEmbedding(vocab)




:param vocab: vocabulary
@@ -641,7 +650,7 @@ class LSTMCharEmbedding(TokenEmbedding):


Example::


>>>
>>> lstm_char_embed = LSTMCharEmbedding(vocab)


:param vocab: vocabulary
:param embed_size: size of the embedding. Default: 50.

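Among the newly exported embeddings, StackEmbedding has no docstring example of its own; a rough sketch, assuming it accepts a list of TokenEmbedding instances and concatenates their outputs per token:

>>> word_embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50')
>>> char_embed = CNNCharEmbedding(vocab)
>>> embed = StackEmbedding([word_embed, char_embed])  # embed_size is the sum of the parts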

+44 -0  reproduction/matching/matching.py

@@ -0,0 +1,44 @@
import os

import torch

from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const

from fastNLP.io.dataset_loader import MatchingLoader

from reproduction.matching.model.bert import BertForNLI
from reproduction.matching.model.esim import ESIMModel


bert_dirs = 'path/to/bert/dir'

# load data set
# data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process(...
data_info = MatchingLoader(data_format='snli', for_model='esim').process(
{'train': './data/snli/snli_1.0_train.jsonl',
'dev': './data/snli/snli_1.0_dev.jsonl',
'test': './data/snli/snli_1.0_test.jsonl'},
input_field=[Const.TARGET]
)

# model = BertForNLI(bert_dir=bert_dirs)
model = ESIMModel(data_info.embeddings['elmo'])

trainer = Trainer(train_data=data_info.datasets['train'], model=model,
optimizer=Adam(lr=1e-4, model_params=model.parameters()),
batch_size=torch.cuda.device_count() * 24, n_epochs=20, print_every=-1,
dev_data=data_info.datasets['dev'],
metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())],
check_code_level=-1)
trainer.train(load_best_model=True)

tester = Tester(
data=data_info.datasets['test'],
model=model,
metrics=AccuracyMetric(),
batch_size=torch.cuda.device_count() * 12,
device=[i for i in range(torch.cuda.device_count())],
)
tester.test()



+182 -0  reproduction/matching/model/esim.py

@@ -0,0 +1,182 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import CrossEntropyLoss

from fastNLP.models import BaseModel
from fastNLP.modules.encoder.embedding import TokenEmbedding
from fastNLP.modules.encoder.lstm import LSTM
from fastNLP.core.const import Const
from fastNLP.core.utils import seq_len_to_mask


class ESIMModel(BaseModel):
def __init__(self, init_embedding: TokenEmbedding, hidden_size=None, num_labels=3, dropout_rate=0.3,
dropout_embed=0.1):
super(ESIMModel, self).__init__()

self.embedding = init_embedding
self.dropout_embed = EmbedDropout(p=dropout_embed)
if hidden_size is None:
hidden_size = self.embedding.embed_size
self.rnn = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate)
# self.rnn = LSTM(self.embedding.embed_size, hidden_size, dropout=dropout_rate, bidirectional=True)

self.interfere = nn.Sequential(nn.Dropout(p=dropout_rate),
nn.Linear(8 * hidden_size, hidden_size),
nn.ReLU())
nn.init.xavier_uniform_(self.interfere[1].weight.data)
self.bi_attention = SoftmaxAttention()

self.rnn_high = BiRNN(hidden_size, hidden_size, dropout_rate=dropout_rate)  # input is the hidden_size-dim output of self.interfere
# self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True)

self.classifier = nn.Sequential(nn.Dropout(p=dropout_rate),
nn.Linear(8 * hidden_size, hidden_size),
nn.Tanh(),
nn.Dropout(p=dropout_rate),
nn.Linear(hidden_size, num_labels))
nn.init.xavier_uniform_(self.classifier[1].weight.data)
nn.init.xavier_uniform_(self.classifier[4].weight.data)

def forward(self, words1, words2, seq_len1, seq_len2, target=None):
mask1 = seq_len_to_mask(seq_len1)
mask2 = seq_len_to_mask(seq_len2)
a0 = self.embedding(words1) # B * len * emb_dim
b0 = self.embedding(words2)
a0, b0 = self.dropout_embed(a0), self.dropout_embed(b0)
a = self.rnn(a0, mask1.byte()) # a: [B, PL, 2 * H]
b = self.rnn(b0, mask2.byte())

ai, bi = self.bi_attention(a, mask1, b, mask2)

a_ = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 8 * H]
b_ = torch.cat((b, bi, b - bi, b * bi), dim=2)
a_f = self.interfere(a_)
b_f = self.interfere(b_)

a_h = self.rnn_high(a_f, mask1.byte()) # ma: [B, PL, 2 * H]
b_h = self.rnn_high(b_f, mask2.byte())

a_avg = self.mean_pooling(a_h, mask1, dim=1)
a_max, _ = self.max_pooling(a_h, mask1, dim=1)
b_avg = self.mean_pooling(b_h, mask2, dim=1)
b_max, _ = self.max_pooling(b_h, mask2, dim=1)

out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1) # v: [B, 8 * H]
logits = torch.tanh(self.classifier(out))

if target is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits, target)

return {Const.LOSS: loss, Const.OUTPUT: logits}
else:
return {Const.OUTPUT: logits}

def predict(self, **kwargs):
return self.forward(**kwargs)

# input [batch_size, len , hidden]
# mask [batch_size, len] (111...00)
@staticmethod
def mean_pooling(input, mask, dim=1):
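# sum embeddings over valid positions only, then divide by each sequence's true length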
masks = mask.view(mask.size(0), mask.size(1), -1).float()
return torch.sum(input * masks, dim=dim) / torch.sum(masks, dim=1)

@staticmethod
def max_pooling(input, mask, dim=1):
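# padded positions are pushed to a large negative value so they can never win the max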
my_inf = 10e12
masks = mask.view(mask.size(0), mask.size(1), -1)
masks = masks.expand(-1, -1, input.size(2)).float()
return torch.max(input + masks.le(0.5).float() * -my_inf, dim=dim)


class EmbedDropout(nn.Dropout):
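# drops whole embedding dimensions: one dropout mask per (batch, feature), shared across all timesteps of a sequence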

def forward(self, sequences_batch):
ones = sequences_batch.data.new_ones(sequences_batch.shape[0], sequences_batch.shape[-1])
dropout_mask = nn.functional.dropout(ones, self.p, self.training, inplace=False)
return dropout_mask.unsqueeze(1) * sequences_batch


class BiRNN(nn.Module):
def __init__(self, input_size, hidden_size, dropout_rate=0.3):
super(BiRNN, self).__init__()
self.dropout_rate = dropout_rate
self.rnn = nn.LSTM(input_size, hidden_size,
num_layers=1,
bidirectional=True,
batch_first=True)

def forward(self, x, x_mask):
# Sort x
lengths = x_mask.data.eq(1).long().sum(1).squeeze()
_, idx_sort = torch.sort(lengths, dim=0, descending=True)
_, idx_unsort = torch.sort(idx_sort, dim=0)
lengths = list(lengths[idx_sort])

x = x.index_select(0, idx_sort)
# Pack it up
rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
# Apply dropout to input
if self.dropout_rate > 0:
dropout_input = F.dropout(rnn_input.data, p=self.dropout_rate, training=self.training)
rnn_input = nn.utils.rnn.PackedSequence(dropout_input, rnn_input.batch_sizes)
output = self.rnn(rnn_input)[0]
# Unpack everything
output = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)[0]
output = output.index_select(0, idx_unsort)
if output.size(1) != x_mask.size(1):
padding = torch.zeros(output.size(0),
x_mask.size(1) - output.size(1),
output.size(2)).type(output.data.type())
output = torch.cat([output, padding], 1)
return output


def masked_softmax(tensor, mask):
tensor_shape = tensor.size()
reshaped_tensor = tensor.view(-1, tensor_shape[-1])

# Reshape the mask so it matches the size of the input tensor.
while mask.dim() < tensor.dim():
mask = mask.unsqueeze(1)
mask = mask.expand_as(tensor).contiguous().float()
reshaped_mask = mask.view(-1, mask.size()[-1])
result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1)
result = result * reshaped_mask
# 1e-13 is added to avoid divisions by zero.
result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
return result.view(*tensor_shape)


def weighted_sum(tensor, weights, mask):
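# weights: [B, len_q, len_k] attention over the other sentence; tensor: [B, len_k, H]; padded query rows of the result are zeroed via mask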
w_sum = weights.bmm(tensor)
while mask.dim() < w_sum.dim():
mask = mask.unsqueeze(1)
mask = mask.transpose(-1, -2)
mask = mask.expand_as(w_sum).contiguous().float()
return w_sum * mask


class SoftmaxAttention(nn.Module):

def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask):
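# similarity_matrix: [B, premise_len, hypothesis_len]; each sentence attends over the other with padding masked out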
similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1)
.contiguous())

prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask)
hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2)
.contiguous(),
premise_mask)

attended_premises = weighted_sum(hypothesis_batch,
prem_hyp_attn,
premise_mask)
attended_hypotheses = weighted_sum(premise_batch,
hyp_prem_attn,
hypothesis_mask)

return attended_premises, attended_hypotheses

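A quick sanity check of masked_softmax defined above (illustrative only; assumes it is imported from reproduction.matching.model.esim): padded hypothesis positions get zero attention weight and every row still sums to one:

import torch

sim = torch.randn(1, 3, 4)                    # [B, premise_len, hypothesis_len]
hyp_mask = torch.tensor([[1., 1., 0., 0.]])   # last two hypothesis tokens are padding
attn = masked_softmax(sim, hyp_mask)
print(attn[..., 2:].abs().max())              # ~0: no weight on padded positions
print(attn.sum(-1))                           # each row sums to ~1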
+0 -88  reproduction/matching/snli.py

@@ -1,88 +0,0 @@
import os

import torch

from fastNLP.core import Vocabulary, DataSet, Trainer, Tester, Const, Adam, AccuracyMetric

from reproduction.matching.data.SNLIDataLoader import SNLILoader
from legacy.component.bert_tokenizer import BertTokenizer
from reproduction.matching.model.bert import BertForNLI


def preprocess_data(data: DataSet, bert_dir):
"""
preprocess data set to bert-need data set.
:param data:
:param bert_dir:
:return:
"""
tokenizer = BertTokenizer.from_pretrained(os.path.join(bert_dir, 'vocab.txt'))

vocab = Vocabulary(padding=None, unknown=None)
with open(os.path.join(bert_dir, 'vocab.txt')) as f:
lines = f.readlines()
vocab_list = []
for line in lines:
vocab_list.append(line.strip())
vocab.add_word_lst(vocab_list)
vocab.build_vocab()
vocab.padding = '[PAD]'
vocab.unknown = '[UNK]'

for i in range(2):
data.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])),
new_field_name=Const.INPUTS(i))
data.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'],
new_field_name=Const.INPUT)
data.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
new_field_name=Const.INPUT_LENS(0))
data.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1))

max_len = 512
data.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT)
data.apply(lambda x: [vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT)
data.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0))
data.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1))

target_vocab = Vocabulary(padding=None, unknown=None)
target_vocab.add_word_lst(['neutral', 'contradiction', 'entailment'])
target_vocab.build_vocab()
data.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET)

data.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET)
data.set_target(Const.TARGET)

return data


bert_dirs = 'path/to/bert/dir'

# load raw data set
train_data = SNLILoader().load('./data/snli/snli_1.0_train.jsonl')
dev_data = SNLILoader().load('./data/snli/snli_1.0_dev.jsonl')
test_data = SNLILoader().load('./data/snli/snli_1.0_test.jsonl')

print('successfully load data sets!')

train_data = preprocess_data(train_data, bert_dirs)
dev_data = preprocess_data(dev_data, bert_dirs)
test_data = preprocess_data(test_data, bert_dirs)

model = BertForNLI(bert_dir=bert_dirs)

trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(lr=2e-5, model_params=model.parameters()),
batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1, dev_data=dev_data,
metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())],
check_code_level=-1)
trainer.train(load_best_model=True)

tester = Tester(
data=test_data,
model=model,
metrics=AccuracyMetric(),
batch_size=torch.cuda.device_count() * 12,
device=[i for i in range(torch.cuda.device_count())],
)
tester.test()


