Browse Source

update matching.py

tags/v0.4.10
xuyige 5 years ago
parent
commit
39388567ad
7 changed files with 229 additions and 96 deletions
  1. +3
    -2
      fastNLP/io/__init__.py
  2. +162
    -1
      fastNLP/io/dataset_loader.py
  3. +8
    -1
      fastNLP/modules/encoder/__init__.py
  4. +1
    -1
      fastNLP/modules/encoder/_bert.py
  5. +11
    -3
      fastNLP/modules/encoder/embedding.py
  6. +44
    -0
      reproduction/matching/matching.py
  7. +0
    -88
      reproduction/matching/snli.py

+ 3
- 2
fastNLP/io/__init__.py View File

@@ -16,6 +16,7 @@ __all__ = [
'CSVLoader',
'JsonLoader',
'ConllLoader',
'MatchingLoader',
'SNLILoader',
'SSTLoader',
'PeopleDailyCorpusLoader',
@@ -26,6 +27,6 @@ __all__ = [
]

from .embed_loader import EmbedLoader
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \
PeopleDailyCorpusLoader, Conll2003Loader
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, MatchingLoader,\
SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader
from .model_io import ModelLoader, ModelSaver

+ 162
- 1
fastNLP/io/dataset_loader.py View File

@@ -16,19 +16,24 @@ __all__ = [
'CSVLoader',
'JsonLoader',
'ConllLoader',
'MatchingLoader',
'SNLILoader',
'SSTLoader',
'PeopleDailyCorpusLoader',
'Conll2003Loader',
]

import os
from nltk import Tree
from typing import Union, Dict
from ..core.vocabulary import Vocabulary
from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll
from .base_loader import DataSetLoader
from .base_loader import DataSetLoader, DataInfo
from .data_loader.sst import SSTLoader
from ..core.const import Const
from ..modules.encoder._bert import BertTokenizer


class PeopleDailyCorpusLoader(DataSetLoader):
@@ -244,6 +249,162 @@ class JsonLoader(DataSetLoader):
return ds


class MatchingLoader(DataSetLoader):
"""
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`

读取Matching数据集,根据数据集做预处理并返回DataInfo。

数据来源:
SNLI: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
"""

def __init__(self, data_format: str='snli', for_model: str='esim', bert_dir=None):
super(MatchingLoader, self).__init__()
self.data_format = data_format.lower()
self.for_model = for_model.lower()
self.bert_dir = bert_dir

def _load(self, path: str) -> DataSet:
raise NotImplementedError

def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
if isinstance(paths, str):
paths = {'train': paths}

data_set = {}
for n, p in paths.items():
if self.data_format == 'snli':
data = self._load_snli(p)
else:
raise RuntimeError(f'Your data format is {self.data_format}, '
f'Please choose data format from [snli]')

if self.for_model == 'esim':
data = self._for_esim(data)
elif self.for_model == 'bert':
data = self._for_bert(data, self.bert_dir)
else:
raise RuntimeError(f'Your model is {self.data_format}, '
f'Please choose from [esim, bert]')

data_set[n] = data
print(f'successfully load {n} set!')

if not hasattr(self, 'vocab'):
raise RuntimeError(f'There is NOT vocab attribute built!')
if not hasattr(self, 'label_vocab'):
raise RuntimeError(f'There is NOT label vocab attribute built!')

if self.for_model != 'bert':
from fastNLP.modules.encoder.embedding import StaticEmbedding
embedding = StaticEmbedding(self.vocab, model_dir_or_name='en')

data_info = DataInfo(vocabs={'vocab': self.vocab, 'target_vocab': self.label_vocab},
embeddings={'glove': embedding} if self.for_model != 'bert' else None,
datasets=data_set)

return data_info

@staticmethod
def _load_snli(path: str) -> DataSet:
"""
读取SNLI数据集

数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
:param str path: 数据集路径
:return:
"""
raw_ds = JsonLoader(
fields={
'sentence1_parse': Const.INPUTS(0),
'sentence2_parse': Const.INPUTS(1),
'gold_label': Const.TARGET,
}
)._load(path)
return raw_ds

def _for_esim(self, raw_ds: DataSet):
if self.data_format == 'snli' or self.data_format == 'mnli':
def parse_tree(x):
t = Tree.fromstring(x)
return t.leaves()

raw_ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
raw_ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
raw_ds.drop(lambda x: x[Const.TARGET] == '-')

if not hasattr(self, 'vocab'):
self.vocab = Vocabulary().from_dataset(raw_ds, [Const.INPUTS(0), Const.INPUTS(1)])
if not hasattr(self, 'label_vocab'):
self.label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(raw_ds, field_name=Const.TARGET)

raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0))
raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1))
raw_ds.apply(lambda ins: self.label_vocab.to_index(Const.TARGET), new_field_name=Const.TARGET)

raw_ds.set_input(Const.INPUTS(0), Const.INPUTS(1))
raw_ds.set_target(Const.TARGET)

return raw_ds

def _for_bert(self, raw_ds: DataSet, bert_dir: str):
if self.data_format == 'snli' or self.data_format == 'mnli':
def parse_tree(x):
t = Tree.fromstring(x)
return t.leaves()

raw_ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
raw_ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
raw_ds.drop(lambda x: x[Const.TARGET] == '-')

tokenizer = BertTokenizer.from_pretrained(bert_dir)

vocab = Vocabulary(padding=None, unknown=None)
with open(os.path.join(bert_dir, 'vocab.txt')) as f:
lines = f.readlines()
vocab_list = []
for line in lines:
vocab_list.append(line.strip())
vocab.add_word_lst(vocab_list)
vocab.build_vocab()
vocab.padding = '[PAD]'
vocab.unknown = '[UNK]'

if not hasattr(self, 'vocab'):
self.vocab = vocab
else:
for w, idx in self.vocab:
if vocab[w] != idx:
raise AttributeError(f"{self.__class__.__name__} has ")

for i in range(2):
raw_ds.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])), new_field_name=Const.INPUTS(i))
raw_ds.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'],
new_field_name=Const.INPUT)
raw_ds.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
new_field_name=Const.INPUT_LENS(0))
raw_ds.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1))

max_len = 512
raw_ds.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT)
raw_ds.apply(lambda x: [self.vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT)
raw_ds.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0))
raw_ds.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1))

if not hasattr(self, 'label_vocab'):
self.label_vocab = Vocabulary(padding=None, unknown=None)
self.label_vocab.from_dataset(raw_ds, field_name=Const.TARGET)
raw_ds.apply(lambda x: self.label_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET)

raw_ds.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1))
raw_ds.set_target(Const.TARGET)


class SNLILoader(JsonLoader):
"""
别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`


+ 8
- 1
fastNLP/modules/encoder/__init__.py View File

@@ -7,6 +7,12 @@ __all__ = [
"ConvMaxpool",
"Embedding",
"StaticEmbedding",
"ElmoEmbedding",
"BertEmbedding",
"StackEmbedding",
"LSTMCharEmbedding",
"CNNCharEmbedding",
"LSTM",
@@ -21,7 +27,8 @@ __all__ = [
from .bert import BertModel
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
from .conv_maxpool import ConvMaxpool
from .embedding import Embedding
from .embedding import Embedding, StaticEmbedding, ElmoEmbedding, BertEmbedding, \
StackEmbedding, LSTMCharEmbedding, CNNCharEmbedding
from .lstm import LSTM
from .star_transformer import StarTransformer
from .transformer import TransformerEncoder


+ 1
- 1
fastNLP/modules/encoder/_bert.py View File

@@ -9,7 +9,7 @@
import torch
from torch import nn

from ... import Vocabulary
from ...core.vocabulary import Vocabulary
import collections

import os


+ 11
- 3
fastNLP/modules/encoder/embedding.py View File

@@ -1,10 +1,16 @@
__all__ = [
"Embedding"
"Embedding",
"StaticEmbedding",
"ElmoEmbedding",
"BertEmbedding",
"StackEmbedding",
"LSTMCharEmbedding",
"CNNCharEmbedding",
]
import torch.nn as nn
from ..utils import get_embeddings
from .lstm import LSTM
from ... import Vocabulary
from ...core.vocabulary import Vocabulary
from abc import abstractmethod
import torch
from ...io import EmbedLoader
@@ -15,7 +21,9 @@ from ...io.file_utils import cached_path, _get_base_url
from ._bert import _WordBertModel
from typing import List

from ... import DataSet, DataSetIter, SequentialSampler
from ...core.dataset import DataSet
from ...core.batch import DataSetIter
from ...core.sampler import SequentialSampler
from ...core.utils import _move_model_to_device, _get_model_device




+ 44
- 0
reproduction/matching/matching.py View File

@@ -0,0 +1,44 @@
import os

import torch

from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric

from fastNLP.io.dataset_loader import MatchingLoader

from reproduction.matching.model.bert import BertForNLI


# bert_dirs = 'path/to/bert/dir'
bert_dirs = '/remote-home/ygxu/BERT/BERT_English_uncased_L-12_H-768_A_12'

# load data set
data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process(
{#'train': './data/snli/snli_1.0_train.jsonl',
'dev': './data/snli/snli_1.0_dev.jsonl',
'test': './data/snli/snli_1.0_test.jsonl'}
)

print('successfully load data sets!')


model = BertForNLI(bert_dir=bert_dirs)

trainer = Trainer(train_data=data_info.datasets['dev'], model=model,
optimizer=Adam(lr=2e-5, model_params=model.parameters()),
batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1,
dev_data=data_info.datasets['dev'],
metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())],
check_code_level=-1)
trainer.train(load_best_model=True)

tester = Tester(
data=data_info.datasets['test'],
model=model,
metrics=AccuracyMetric(),
batch_size=torch.cuda.device_count() * 12,
device=[i for i in range(torch.cuda.device_count())],
)
tester.test()



+ 0
- 88
reproduction/matching/snli.py View File

@@ -1,88 +0,0 @@
import os

import torch

from fastNLP.core import Vocabulary, DataSet, Trainer, Tester, Const, Adam, AccuracyMetric

from reproduction.matching.data.SNLIDataLoader import SNLILoader
from legacy.component.bert_tokenizer import BertTokenizer
from reproduction.matching.model.bert import BertForNLI


def preprocess_data(data: DataSet, bert_dir):
"""
preprocess data set to bert-need data set.
:param data:
:param bert_dir:
:return:
"""
tokenizer = BertTokenizer.from_pretrained(os.path.join(bert_dir, 'vocab.txt'))

vocab = Vocabulary(padding=None, unknown=None)
with open(os.path.join(bert_dir, 'vocab.txt')) as f:
lines = f.readlines()
vocab_list = []
for line in lines:
vocab_list.append(line.strip())
vocab.add_word_lst(vocab_list)
vocab.build_vocab()
vocab.padding = '[PAD]'
vocab.unknown = '[UNK]'

for i in range(2):
data.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])),
new_field_name=Const.INPUTS(i))
data.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'],
new_field_name=Const.INPUT)
data.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
new_field_name=Const.INPUT_LENS(0))
data.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1))

max_len = 512
data.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT)
data.apply(lambda x: [vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT)
data.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0))
data.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1))

target_vocab = Vocabulary(padding=None, unknown=None)
target_vocab.add_word_lst(['neutral', 'contradiction', 'entailment'])
target_vocab.build_vocab()
data.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET)

data.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET)
data.set_target(Const.TARGET)

return data


bert_dirs = 'path/to/bert/dir'

# load raw data set
train_data = SNLILoader().load('./data/snli/snli_1.0_train.jsonl')
dev_data = SNLILoader().load('./data/snli/snli_1.0_dev.jsonl')
test_data = SNLILoader().load('./data/snli/snli_1.0_test.jsonl')

print('successfully load data sets!')

train_data = preprocess_data(train_data, bert_dirs)
dev_data = preprocess_data(dev_data, bert_dirs)
test_data = preprocess_data(test_data, bert_dirs)

model = BertForNLI(bert_dir=bert_dirs)

trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(lr=2e-5, model_params=model.parameters()),
batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1, dev_data=dev_data,
metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())],
check_code_level=-1)
trainer.train(load_best_model=True)

tester = Tester(
data=test_data,
model=model,
metrics=AccuracyMetric(),
batch_size=torch.cuda.device_count() * 12,
device=[i for i in range(torch.cuda.device_count())],
)
tester.test()



Loading…
Cancel
Save