
Merge branch 'dev0.5.0' of github.com:fastnlp/fastNLP into dev0.5.0

tags/v0.4.10
yh_cc committed 5 years ago (commit 4d138ed7f8)
9 changed files with 430 additions and 103 deletions
  1. fastNLP/io/__init__.py (+3, -2)
  2. fastNLP/io/dataset_loader.py (+173, -1)
  3. fastNLP/modules/encoder/__init__.py (+8, -1)
  4. fastNLP/modules/encoder/_bert.py (+1, -1)
  5. fastNLP/modules/encoder/bert.py (+2, -2)
  6. fastNLP/modules/encoder/embedding.py (+17, -8)
  7. reproduction/matching/matching.py (+44, -0)
  8. reproduction/matching/model/esim.py (+182, -0)
  9. reproduction/matching/snli.py (+0, -88)

fastNLP/io/__init__.py (+3, -2)

@@ -16,6 +16,7 @@ __all__ = [
'CSVLoader',
'JsonLoader',
'ConllLoader',
'MatchingLoader',
'SNLILoader',
'SSTLoader',
'PeopleDailyCorpusLoader',
@@ -26,6 +27,6 @@ __all__ = [
]

from .embed_loader import EmbedLoader
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \
PeopleDailyCorpusLoader, Conll2003Loader
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, MatchingLoader,\
SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader
from .model_io import ModelLoader, ModelSaver

fastNLP/io/dataset_loader.py (+173, -1)

@@ -16,19 +16,24 @@ __all__ = [
'CSVLoader',
'JsonLoader',
'ConllLoader',
'MatchingLoader',
'SNLILoader',
'SSTLoader',
'PeopleDailyCorpusLoader',
'Conll2003Loader',
]

import os
from nltk import Tree
from typing import Union, Dict
from ..core.vocabulary import Vocabulary
from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll
from .base_loader import DataSetLoader
from .base_loader import DataSetLoader, DataInfo
from .data_loader.sst import SSTLoader
from ..core.const import Const
from ..modules.encoder._bert import BertTokenizer


class PeopleDailyCorpusLoader(DataSetLoader):
@@ -245,6 +250,173 @@ class JsonLoader(DataSetLoader):
return ds


class MatchingLoader(DataSetLoader):
"""
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`

读取Matching数据集,根据数据集做预处理并返回DataInfo。

数据来源:
SNLI: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
"""

def __init__(self, data_format: str='snli', for_model: str='esim', bert_dir=None):
super(MatchingLoader, self).__init__()
self.data_format = data_format.lower()
self.for_model = for_model.lower()
self.bert_dir = bert_dir

def _load(self, path: str) -> DataSet:
raise NotImplementedError

def process(self, paths: Union[str, Dict[str, str]], input_field=None) -> DataInfo:
if isinstance(paths, str):
paths = {'train': paths}

data_set = {}
for n, p in paths.items():
if self.data_format == 'snli':
data = self._load_snli(p)
else:
raise RuntimeError(f'Your data format is {self.data_format}, '
f'please choose a data format from [snli]')

if self.for_model == 'esim':
data = self._for_esim(data)
elif self.for_model == 'bert':
data = self._for_bert(data, self.bert_dir)
else:
raise RuntimeError(f'Your model is {self.for_model}, '
f'please choose a model from [esim, bert]')

if input_field is not None:
if isinstance(input_field, str):
data.set_input(input_field)
elif isinstance(input_field, list):
for field in input_field:
data.set_input(field)

data_set[n] = data
print(f'successfully loaded the {n} set!')

if not hasattr(self, 'vocab'):
raise RuntimeError('No vocab attribute has been built!')
if not hasattr(self, 'label_vocab'):
raise RuntimeError('No label_vocab attribute has been built!')

if self.for_model != 'bert':
from fastNLP.modules.encoder.embedding import ElmoEmbedding
embedding = ElmoEmbedding(self.vocab, model_dir_or_name='en', requires_grad=True, layers='2')

data_info = DataInfo(vocabs={'vocab': self.vocab, 'target_vocab': self.label_vocab},
embeddings={'elmo': embedding} if self.for_model != 'bert' else None,
datasets=data_set)

return data_info

@staticmethod
def _load_snli(path: str) -> DataSet:
"""
读取SNLI数据集

数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
:param str path: 数据集路径
:return:
"""
raw_ds = JsonLoader(
fields={
'sentence1_parse': Const.INPUTS(0),
'sentence2_parse': Const.INPUTS(1),
'gold_label': Const.TARGET,
}
)._load(path)
return raw_ds

def _for_esim(self, raw_ds: DataSet):
if self.data_format == 'snli' or self.data_format == 'mnli':
def parse_tree(x):
t = Tree.fromstring(x)
return t.leaves()

raw_ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
raw_ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
raw_ds.drop(lambda x: x[Const.TARGET] == '-')

if not hasattr(self, 'vocab'):
self.vocab = Vocabulary().from_dataset(raw_ds, field_name=[Const.INPUTS(0), Const.INPUTS(1)])
if not hasattr(self, 'label_vocab'):
self.label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(raw_ds, field_name=Const.TARGET)

raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0))
raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1))
raw_ds.apply(lambda ins: self.label_vocab.to_index(ins[Const.TARGET]), new_field_name=Const.TARGET)
raw_ds.apply(lambda ins: len(ins[Const.INPUTS(0)]), new_field_name=Const.INPUT_LENS(0))
raw_ds.apply(lambda ins: len(ins[Const.INPUTS(1)]), new_field_name=Const.INPUT_LENS(1))

raw_ds.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1))
raw_ds.set_target(Const.TARGET)

return raw_ds

def _for_bert(self, raw_ds: DataSet, bert_dir: str):
if self.data_format == 'snli' or self.data_format == 'mnli':
def parse_tree(x):
t = Tree.fromstring(x)
return t.leaves()

raw_ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
raw_ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
raw_ds.drop(lambda x: x[Const.TARGET] == '-')

tokenizer = BertTokenizer.from_pretrained(bert_dir)

vocab = Vocabulary(padding=None, unknown=None)
with open(os.path.join(bert_dir, 'vocab.txt')) as f:
lines = f.readlines()
vocab_list = []
for line in lines:
vocab_list.append(line.strip())
vocab.add_word_lst(vocab_list)
vocab.build_vocab()
vocab.padding = '[PAD]'
vocab.unknown = '[UNK]'

if not hasattr(self, 'vocab'):
self.vocab = vocab
else:
for w, idx in self.vocab:
if vocab[w] != idx:
raise AttributeError(f"{self.__class__.__name__} already has a vocab that is inconsistent with the BERT vocab")

for i in range(2):
raw_ds.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])), new_field_name=Const.INPUTS(i))
raw_ds.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'],
new_field_name=Const.INPUT)
raw_ds.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
new_field_name=Const.INPUT_LENS(0))
raw_ds.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1))

max_len = 512
raw_ds.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT)
raw_ds.apply(lambda x: [self.vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT)
raw_ds.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0))
raw_ds.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1))

if not hasattr(self, 'label_vocab'):
self.label_vocab = Vocabulary(padding=None, unknown=None)
self.label_vocab.from_dataset(raw_ds, field_name=Const.TARGET)
raw_ds.apply(lambda x: self.label_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET)

raw_ds.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1))
raw_ds.set_target(Const.TARGET)

return raw_ds


class SNLILoader(JsonLoader):
"""
Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`
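For context, a minimal usage sketch of the new MatchingLoader (mirroring reproduction/matching/matching.py later in this commit, with the same SNLI paths):

from fastNLP.core.const import Const
from fastNLP.io.dataset_loader import MatchingLoader

# process() loads every split, tokenizes/indexes it for the chosen model and
# returns a DataInfo with the datasets, the word/label vocabularies and,
# for non-BERT models, an ELMo embedding under the key 'elmo'.
data_info = MatchingLoader(data_format='snli', for_model='esim').process(
    {'train': './data/snli/snli_1.0_train.jsonl',
     'dev': './data/snli/snli_1.0_dev.jsonl',
     'test': './data/snli/snli_1.0_test.jsonl'},
    input_field=[Const.TARGET]
)
print(len(data_info.datasets['train']))            # number of training instances
print(data_info.vocabs['target_vocab'].word2idx)   # the NLI label vocabulary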


fastNLP/modules/encoder/__init__.py (+8, -1)

@@ -7,6 +7,12 @@ __all__ = [
"ConvMaxpool",
"Embedding",
"StaticEmbedding",
"ElmoEmbedding",
"BertEmbedding",
"StackEmbedding",
"LSTMCharEmbedding",
"CNNCharEmbedding",
"LSTM",
@@ -22,7 +28,8 @@ from ._bert import BertModel
from .bert import BertWordPieceEncoder
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
from .conv_maxpool import ConvMaxpool
from .embedding import Embedding
from .embedding import Embedding, StaticEmbedding, ElmoEmbedding, BertEmbedding, \
StackEmbedding, LSTMCharEmbedding, CNNCharEmbedding
from .lstm import LSTM
from .star_transformer import StarTransformer
from .transformer import TransformerEncoder


fastNLP/modules/encoder/_bert.py (+1, -1)

@@ -7,7 +7,7 @@



from ... import Vocabulary
from ...core.vocabulary import Vocabulary
import collections

import unicodedata


fastNLP/modules/encoder/bert.py (+2, -2)

@@ -2,9 +2,9 @@
import os
from torch import nn
import torch
from ...core import Vocabulary
from ...core.vocabulary import Vocabulary
from ...io.file_utils import _get_base_url, cached_path
from ._bert import _WordPieceBertModel
from ._bert import _WordPieceBertModel, BertModel


class BertWordPieceEncoder(nn.Module):


fastNLP/modules/encoder/embedding.py (+17, -8)

@@ -1,10 +1,16 @@
__all__ = [
"Embedding"
"Embedding",
"StaticEmbedding",
"ElmoEmbedding",
"BertEmbedding",
"StackEmbedding",
"LSTMCharEmbedding",
"CNNCharEmbedding",
]
import torch.nn as nn
from ..utils import get_embeddings
from .lstm import LSTM
from ... import Vocabulary
from ...core.vocabulary import Vocabulary
from abc import abstractmethod
import torch
from ...io import EmbedLoader
@@ -15,7 +21,9 @@ from ...io.file_utils import cached_path, _get_base_url
from ._bert import _WordBertModel
from typing import List

from ... import DataSet, DataSetIter, SequentialSampler
from ...core.dataset import DataSet
from ...core.batch import DataSetIter
from ...core.sampler import SequentialSampler
from ...core.utils import _move_model_to_device, _get_model_device


@@ -144,6 +152,8 @@ class StaticEmbedding(TokenEmbedding):

Example::

>>> embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50')


:param vocab: Vocabulary. If None, all embeddings in the file are loaded.
:param model_dir_or_name: a pretrained static embedding can be used in two ways: the first is to pass the file name of the embedding, the second is to pass the embedding
@@ -303,8 +313,7 @@ class ElmoEmbedding(ContextualEmbedding):

Example::

>>>
>>>
>>> embedding = ElmoEmbedding(vocab, model_dir_or_name='en', layers='2', requires_grad=True)

:param vocab: the vocabulary
:param model_dir_or_name: a pretrained ELMo embedding can be used in two ways: the first is to pass the file name of the ELMo weights, the second is to pass the name of an ELMo release,
@@ -395,7 +404,7 @@ class BertEmbedding(ContextualEmbedding):

Example::

>>>
>>> embedding = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', requires_grad=False, layers='4,-2,-1')


:param fastNLP.Vocabulary vocab: the vocabulary
@@ -505,7 +514,7 @@ class CNNCharEmbedding(TokenEmbedding):

Example::

>>>
>>> cnn_char_embed = CNNCharEmbedding(vocab)


:param vocab: the vocabulary
@@ -641,7 +650,7 @@ class LSTMCharEmbedding(TokenEmbedding):

Example::

>>>
>>> lstm_char_embed = LSTMCharEmbedding(vocab)

:param vocab: the vocabulary
:param embed_size: size of the embedding. Default: 50.
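StackEmbedding is newly exported here but has no example in this hunk; a hedged sketch of stacking a word-level and a character-level embedding (assuming StackEmbedding takes a list of TokenEmbedding objects and concatenates their outputs):

from fastNLP import Vocabulary
from fastNLP.modules.encoder import StaticEmbedding, CNNCharEmbedding, StackEmbedding

vocab = Vocabulary()
vocab.add_word_lst("a man is playing a guitar".split())
vocab.build_vocab()

word_embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50')  # as in the StaticEmbedding example
char_embed = CNNCharEmbedding(vocab)                                     # as in the CNNCharEmbedding example
embed = StackEmbedding([word_embed, char_embed])   # assumption: output dim is the sum of the two embed_sizes
print(embed.embed_size)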


reproduction/matching/matching.py (+44, -0)

@@ -0,0 +1,44 @@
import os

import torch

from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const

from fastNLP.io.dataset_loader import MatchingLoader

from reproduction.matching.model.bert import BertForNLI
from reproduction.matching.model.esim import ESIMModel


bert_dirs = 'path/to/bert/dir'

# load data set
# data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process(...
data_info = MatchingLoader(data_format='snli', for_model='esim').process(
{'train': './data/snli/snli_1.0_train.jsonl',
'dev': './data/snli/snli_1.0_dev.jsonl',
'test': './data/snli/snli_1.0_test.jsonl'},
input_field=[Const.TARGET]
)

# model = BertForNLI(bert_dir=bert_dirs)
model = ESIMModel(data_info.embeddings['elmo'],)

trainer = Trainer(train_data=data_info.datasets['train'], model=model,
optimizer=Adam(lr=1e-4, model_params=model.parameters()),
batch_size=torch.cuda.device_count() * 24, n_epochs=20, print_every=-1,
dev_data=data_info.datasets['dev'],
metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())],
check_code_level=-1)
trainer.train(load_best_model=True)

tester = Tester(
data=data_info.datasets['test'],
model=model,
metrics=AccuracyMetric(),
batch_size=torch.cuda.device_count() * 12,
device=[i for i in range(torch.cuda.device_count())],
)
tester.test()
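To train the BERT baseline instead of ESIM, the two commented lines above suggest the following switch (a sketch; bert_dirs must point to a directory containing the pretrained weights and vocab.txt):

data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process(
    {'train': './data/snli/snli_1.0_train.jsonl',
     'dev': './data/snli/snli_1.0_dev.jsonl',
     'test': './data/snli/snli_1.0_test.jsonl'},
    input_field=[Const.TARGET]
)
model = BertForNLI(bert_dir=bert_dirs)
# the Trainer and Tester calls above can then be reused unchanged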



reproduction/matching/model/esim.py (+182, -0)

@@ -0,0 +1,182 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import CrossEntropyLoss

from fastNLP.models import BaseModel
from fastNLP.modules.encoder.embedding import TokenEmbedding
from fastNLP.modules.encoder.lstm import LSTM
from fastNLP.core.const import Const
from fastNLP.core.utils import seq_len_to_mask


class ESIMModel(BaseModel):
def __init__(self, init_embedding: TokenEmbedding, hidden_size=None, num_labels=3, dropout_rate=0.3,
dropout_embed=0.1):
super(ESIMModel, self).__init__()

self.embedding = init_embedding
self.dropout_embed = EmbedDropout(p=dropout_embed)
if hidden_size is None:
hidden_size = self.embedding.embed_size
self.rnn = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate)
# self.rnn = LSTM(self.embedding.embed_size, hidden_size, dropout=dropout_rate, bidirectional=True)

self.interfere = nn.Sequential(nn.Dropout(p=dropout_rate),
nn.Linear(8 * hidden_size, hidden_size),
nn.ReLU())
nn.init.xavier_uniform_(self.interfere[1].weight.data)
self.bi_attention = SoftmaxAttention()

self.rnn_high = BiRNN(hidden_size, hidden_size, dropout_rate=dropout_rate)  # takes the hidden_size-dim output of self.interfere
# self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True)

self.classifier = nn.Sequential(nn.Dropout(p=dropout_rate),
nn.Linear(8 * hidden_size, hidden_size),
nn.Tanh(),
nn.Dropout(p=dropout_rate),
nn.Linear(hidden_size, num_labels))
nn.init.xavier_uniform_(self.classifier[1].weight.data)
nn.init.xavier_uniform_(self.classifier[4].weight.data)

def forward(self, words1, words2, seq_len1, seq_len2, target=None):
mask1 = seq_len_to_mask(seq_len1)
mask2 = seq_len_to_mask(seq_len2)
a0 = self.embedding(words1) # B * len * emb_dim
b0 = self.embedding(words2)
a0, b0 = self.dropout_embed(a0), self.dropout_embed(b0)
a = self.rnn(a0, mask1.byte()) # a: [B, PL, 2 * H]
b = self.rnn(b0, mask2.byte())

ai, bi = self.bi_attention(a, mask1, b, mask2)

a_ = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 8 * H]
b_ = torch.cat((b, bi, b - bi, b * bi), dim=2)
a_f = self.interfere(a_)
b_f = self.interfere(b_)

a_h = self.rnn_high(a_f, mask1.byte()) # ma: [B, PL, 2 * H]
b_h = self.rnn_high(b_f, mask2.byte())

a_avg = self.mean_pooling(a_h, mask1, dim=1)
a_max, _ = self.max_pooling(a_h, mask1, dim=1)
b_avg = self.mean_pooling(b_h, mask2, dim=1)
b_max, _ = self.max_pooling(b_h, mask2, dim=1)

out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1) # v: [B, 8 * H]
logits = torch.tanh(self.classifier(out))

if target is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits, target)

return {Const.LOSS: loss, Const.OUTPUT: logits}
else:
return {Const.OUTPUT: logits}

def predict(self, **kwargs):
return self.forward(**kwargs)

# input [batch_size, len , hidden]
# mask [batch_size, len] (111...00)
@staticmethod
def mean_pooling(input, mask, dim=1):
masks = mask.view(mask.size(0), mask.size(1), -1).float()
return torch.sum(input * masks, dim=dim) / torch.sum(masks, dim=1)

@staticmethod
def max_pooling(input, mask, dim=1):
my_inf = 10e12
masks = mask.view(mask.size(0), mask.size(1), -1)
masks = masks.expand(-1, -1, input.size(2)).float()
return torch.max(input + masks.le(0.5).float() * -my_inf, dim=dim)


class EmbedDropout(nn.Dropout):

def forward(self, sequences_batch):
ones = sequences_batch.data.new_ones(sequences_batch.shape[0], sequences_batch.shape[-1])
dropout_mask = nn.functional.dropout(ones, self.p, self.training, inplace=False)
return dropout_mask.unsqueeze(1) * sequences_batch


class BiRNN(nn.Module):
def __init__(self, input_size, hidden_size, dropout_rate=0.3):
super(BiRNN, self).__init__()
self.dropout_rate = dropout_rate
self.rnn = nn.LSTM(input_size, hidden_size,
num_layers=1,
bidirectional=True,
batch_first=True)

def forward(self, x, x_mask):
# Sort x
lengths = x_mask.data.eq(1).long().sum(1).squeeze()
_, idx_sort = torch.sort(lengths, dim=0, descending=True)
_, idx_unsort = torch.sort(idx_sort, dim=0)
lengths = list(lengths[idx_sort])

x = x.index_select(0, idx_sort)
# Pack it up
rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
# Apply dropout to input
if self.dropout_rate > 0:
dropout_input = F.dropout(rnn_input.data, p=self.dropout_rate, training=self.training)
rnn_input = nn.utils.rnn.PackedSequence(dropout_input, rnn_input.batch_sizes)
output = self.rnn(rnn_input)[0]
# Unpack everything
output = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)[0]
output = output.index_select(0, idx_unsort)
if output.size(1) != x_mask.size(1):
padding = torch.zeros(output.size(0),
x_mask.size(1) - output.size(1),
output.size(2)).type(output.data.type())
output = torch.cat([output, padding], 1)
return output


def masked_softmax(tensor, mask):
tensor_shape = tensor.size()
reshaped_tensor = tensor.view(-1, tensor_shape[-1])

# Reshape the mask so it matches the size of the input tensor.
while mask.dim() < tensor.dim():
mask = mask.unsqueeze(1)
mask = mask.expand_as(tensor).contiguous().float()
reshaped_mask = mask.view(-1, mask.size()[-1])
result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1)
result = result * reshaped_mask
# 1e-13 is added to avoid divisions by zero.
result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
return result.view(*tensor_shape)


def weighted_sum(tensor, weights, mask):
w_sum = weights.bmm(tensor)
while mask.dim() < w_sum.dim():
mask = mask.unsqueeze(1)
mask = mask.transpose(-1, -2)
mask = mask.expand_as(w_sum).contiguous().float()
return w_sum * mask


class SoftmaxAttention(nn.Module):

def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask):
similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1)
.contiguous())

prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask)
hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2)
.contiguous(),
premise_mask)

attended_premises = weighted_sum(hypothesis_batch,
prem_hyp_attn,
premise_mask)
attended_hypotheses = weighted_sum(premise_batch,
hyp_prem_attn,
hypothesis_mask)

return attended_premises, attended_hypotheses
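A quick, self-contained illustration (not part of the commit) of the masked attention used by SoftmaxAttention: padded hypothesis positions receive roughly zero weight, while every premise row still sums to about 1.

import torch
from reproduction.matching.model.esim import masked_softmax, weighted_sum

premise = torch.randn(1, 3, 4)                    # [B, premise_len, H]
hypothesis = torch.randn(1, 5, 4)                 # [B, hypothesis_len, H]
hyp_mask = torch.tensor([[1., 1., 1., 0., 0.]])   # last two hypothesis tokens are padding
prem_mask = torch.ones(1, 3)

similarity = premise.bmm(hypothesis.transpose(1, 2))      # [B, 3, 5]
attn = masked_softmax(similarity, hyp_mask)               # padded columns get ~0 probability
print(attn.sum(dim=-1))                                   # each row sums to ~1
aligned_hyp = weighted_sum(hypothesis, attn, prem_mask)   # [B, 3, 4]: hypothesis re-weighted per premise token
print(aligned_hyp.shape)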

reproduction/matching/snli.py (+0, -88)

@@ -1,88 +0,0 @@
import os

import torch

from fastNLP.core import Vocabulary, DataSet, Trainer, Tester, Const, Adam, AccuracyMetric

from reproduction.matching.data.SNLIDataLoader import SNLILoader
from legacy.component.bert_tokenizer import BertTokenizer
from reproduction.matching.model.bert import BertForNLI


def preprocess_data(data: DataSet, bert_dir):
"""
Preprocess the data set into the format required by BERT.
:param data:
:param bert_dir:
:return:
"""
tokenizer = BertTokenizer.from_pretrained(os.path.join(bert_dir, 'vocab.txt'))

vocab = Vocabulary(padding=None, unknown=None)
with open(os.path.join(bert_dir, 'vocab.txt')) as f:
lines = f.readlines()
vocab_list = []
for line in lines:
vocab_list.append(line.strip())
vocab.add_word_lst(vocab_list)
vocab.build_vocab()
vocab.padding = '[PAD]'
vocab.unknown = '[UNK]'

for i in range(2):
data.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])),
new_field_name=Const.INPUTS(i))
data.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'],
new_field_name=Const.INPUT)
data.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
new_field_name=Const.INPUT_LENS(0))
data.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1))

max_len = 512
data.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT)
data.apply(lambda x: [vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT)
data.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0))
data.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1))

target_vocab = Vocabulary(padding=None, unknown=None)
target_vocab.add_word_lst(['neutral', 'contradiction', 'entailment'])
target_vocab.build_vocab()
data.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET)

data.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET)
data.set_target(Const.TARGET)

return data


bert_dirs = 'path/to/bert/dir'

# load raw data set
train_data = SNLILoader().load('./data/snli/snli_1.0_train.jsonl')
dev_data = SNLILoader().load('./data/snli/snli_1.0_dev.jsonl')
test_data = SNLILoader().load('./data/snli/snli_1.0_test.jsonl')

print('successfully load data sets!')

train_data = preprocess_data(train_data, bert_dirs)
dev_data = preprocess_data(dev_data, bert_dirs)
test_data = preprocess_data(test_data, bert_dirs)

model = BertForNLI(bert_dir=bert_dirs)

trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(lr=2e-5, model_params=model.parameters()),
batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1, dev_data=dev_data,
metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())],
check_code_level=-1)
trainer.train(load_best_model=True)

tester = Tester(
data=test_data,
model=model,
metrics=AccuracyMetric(),
batch_size=torch.cuda.device_count() * 12,
device=[i for i in range(torch.cuda.device_count())],
)
tester.test()


