
update matching.py

tags/v0.4.10
xuyige 5 years ago
commit 39388567ad
7 changed files with 229 additions and 96 deletions
  1. fastNLP/io/__init__.py (+3, -2)
  2. fastNLP/io/dataset_loader.py (+162, -1)
  3. fastNLP/modules/encoder/__init__.py (+8, -1)
  4. fastNLP/modules/encoder/_bert.py (+1, -1)
  5. fastNLP/modules/encoder/embedding.py (+11, -3)
  6. reproduction/matching/matching.py (+44, -0)
  7. reproduction/matching/snli.py (+0, -88)

fastNLP/io/__init__.py (+3, -2)

@@ -16,6 +16,7 @@ __all__ = [
     'CSVLoader',
     'JsonLoader',
     'ConllLoader',
+    'MatchingLoader',
     'SNLILoader',
     'SSTLoader',
     'PeopleDailyCorpusLoader',
@@ -26,6 +27,6 @@ __all__ = [
 ]

 from .embed_loader import EmbedLoader
-from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \
-    PeopleDailyCorpusLoader, Conll2003Loader
+from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, MatchingLoader, \
+    SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader
 from .model_io import ModelLoader, ModelSaver

fastNLP/io/dataset_loader.py (+162, -1)

@@ -16,19 +16,24 @@ __all__ = [
     'CSVLoader',
     'JsonLoader',
     'ConllLoader',
+    'MatchingLoader',
     'SNLILoader',
     'SSTLoader',
     'PeopleDailyCorpusLoader',
     'Conll2003Loader',
 ]

+import os
 from nltk import Tree
+from typing import Union, Dict
+from ..core.vocabulary import Vocabulary
 from ..core.dataset import DataSet
 from ..core.instance import Instance
 from .file_reader import _read_csv, _read_json, _read_conll
-from .base_loader import DataSetLoader
+from .base_loader import DataSetLoader, DataInfo
 from .data_loader.sst import SSTLoader
 from ..core.const import Const
+from ..modules.encoder._bert import BertTokenizer


 class PeopleDailyCorpusLoader(DataSetLoader):
@@ -244,6 +249,162 @@ class JsonLoader(DataSetLoader):
         return ds


+class MatchingLoader(DataSetLoader):
+    """
+    Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`
+
+    Reads a matching (sentence-pair) dataset, preprocesses it for the chosen model and returns a DataInfo.
+
+    Data source:
+        SNLI: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
+    """
+
+    def __init__(self, data_format: str='snli', for_model: str='esim', bert_dir=None):
+        super(MatchingLoader, self).__init__()
+        self.data_format = data_format.lower()
+        self.for_model = for_model.lower()
+        self.bert_dir = bert_dir
+
+    def _load(self, path: str) -> DataSet:
+        raise NotImplementedError
+
+    def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
+        if isinstance(paths, str):
+            paths = {'train': paths}
+
+        data_set = {}
+        for n, p in paths.items():
+            if self.data_format == 'snli':
+                data = self._load_snli(p)
+            else:
+                raise RuntimeError(f'Your data format is {self.data_format}, '
+                                   'please choose a data format from [snli].')
+
+            if self.for_model == 'esim':
+                data = self._for_esim(data)
+            elif self.for_model == 'bert':
+                data = self._for_bert(data, self.bert_dir)
+            else:
+                raise RuntimeError(f'Your model is {self.for_model}, '
+                                   'please choose a model from [esim, bert].')
+
+            data_set[n] = data
+            print(f'successfully loaded the {n} set!')
+
+        if not hasattr(self, 'vocab'):
+            raise RuntimeError('No vocab attribute has been built!')
+        if not hasattr(self, 'label_vocab'):
+            raise RuntimeError('No label_vocab attribute has been built!')
+
+        if self.for_model != 'bert':
+            from fastNLP.modules.encoder.embedding import StaticEmbedding
+            embedding = StaticEmbedding(self.vocab, model_dir_or_name='en')
+
+        data_info = DataInfo(vocabs={'vocab': self.vocab, 'target_vocab': self.label_vocab},
+                             embeddings={'glove': embedding} if self.for_model != 'bert' else None,
+                             datasets=data_set)
+
+        return data_info
+
+    @staticmethod
+    def _load_snli(path: str) -> DataSet:
+        """
+        Reads the SNLI dataset.
+
+        Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
+
+        :param str path: path to the data file
+        :return:
+        """
+        raw_ds = JsonLoader(
+            fields={
+                'sentence1_parse': Const.INPUTS(0),
+                'sentence2_parse': Const.INPUTS(1),
+                'gold_label': Const.TARGET,
+            }
+        )._load(path)
+        return raw_ds
+
+    def _for_esim(self, raw_ds: DataSet):
+        if self.data_format == 'snli' or self.data_format == 'mnli':
+            def parse_tree(x):
+                t = Tree.fromstring(x)
+                return t.leaves()
+
+            raw_ds.apply(lambda ins: parse_tree(
+                ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
+            raw_ds.apply(lambda ins: parse_tree(
+                ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
+            raw_ds.drop(lambda x: x[Const.TARGET] == '-')
+
+        if not hasattr(self, 'vocab'):
+            self.vocab = Vocabulary().from_dataset(raw_ds, [Const.INPUTS(0), Const.INPUTS(1)])
+        if not hasattr(self, 'label_vocab'):
+            self.label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(raw_ds, field_name=Const.TARGET)
+
+        raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0))
+        raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1))
+        raw_ds.apply(lambda ins: self.label_vocab.to_index(ins[Const.TARGET]), new_field_name=Const.TARGET)
+
+        raw_ds.set_input(Const.INPUTS(0), Const.INPUTS(1))
+        raw_ds.set_target(Const.TARGET)
+
+        return raw_ds
+
+    def _for_bert(self, raw_ds: DataSet, bert_dir: str):
+        if self.data_format == 'snli' or self.data_format == 'mnli':
+            def parse_tree(x):
+                t = Tree.fromstring(x)
+                return t.leaves()
+
+            raw_ds.apply(lambda ins: parse_tree(
+                ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
+            raw_ds.apply(lambda ins: parse_tree(
+                ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
+            raw_ds.drop(lambda x: x[Const.TARGET] == '-')
+
+        tokenizer = BertTokenizer.from_pretrained(bert_dir)
+
+        vocab = Vocabulary(padding=None, unknown=None)
+        with open(os.path.join(bert_dir, 'vocab.txt')) as f:
+            lines = f.readlines()
+        vocab_list = []
+        for line in lines:
+            vocab_list.append(line.strip())
+        vocab.add_word_lst(vocab_list)
+        vocab.build_vocab()
+        vocab.padding = '[PAD]'
+        vocab.unknown = '[UNK]'
+
+        if not hasattr(self, 'vocab'):
+            self.vocab = vocab
+        else:
+            for w, idx in self.vocab:
+                if vocab[w] != idx:
+                    raise AttributeError(f"{self.__class__.__name__} has a vocab that is inconsistent "
+                                         f"with the BERT vocab in {bert_dir}")
+
+        for i in range(2):
+            raw_ds.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])), new_field_name=Const.INPUTS(i))
+        raw_ds.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'],
+                     new_field_name=Const.INPUT)
+        raw_ds.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
+                     new_field_name=Const.INPUT_LENS(0))
+        raw_ds.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1))
+
+        max_len = 512
+        raw_ds.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT)
+        raw_ds.apply(lambda x: [self.vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT)
+        raw_ds.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0))
+        raw_ds.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1))
+
+        if not hasattr(self, 'label_vocab'):
+            self.label_vocab = Vocabulary(padding=None, unknown=None)
+            self.label_vocab.from_dataset(raw_ds, field_name=Const.TARGET)
+        raw_ds.apply(lambda x: self.label_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET)
+
+        raw_ds.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1))
+        raw_ds.set_target(Const.TARGET)
+
+        return raw_ds
+
+
 class SNLILoader(JsonLoader):
     """
     Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`
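
For orientation, a minimal usage sketch of the new loader, based only on the class body above (the file paths below are placeholders, and 'en' is the GloVe alias that process() itself uses):

    from fastNLP.io.dataset_loader import MatchingLoader

    # process() returns a DataInfo bundling the datasets, vocabularies and (for ESIM) a GloVe embedding
    data_info = MatchingLoader(data_format='snli', for_model='esim').process(
        {'train': 'path/to/snli_1.0_train.jsonl', 'dev': 'path/to/snli_1.0_dev.jsonl'})
    train_set = data_info.datasets['train']          # fastNLP DataSet with inputs/target already set
    word_vocab = data_info.vocabs['vocab']           # word vocabulary shared by both sentences
    label_vocab = data_info.vocabs['target_vocab']   # label vocabulary
    glove_embed = data_info.embeddings['glove']      # StaticEmbedding built over word_vocab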


fastNLP/modules/encoder/__init__.py (+8, -1)

@@ -7,6 +7,12 @@ __all__ = [
     "ConvMaxpool",
     "Embedding",
+    "StaticEmbedding",
+    "ElmoEmbedding",
+    "BertEmbedding",
+    "StackEmbedding",
+    "LSTMCharEmbedding",
+    "CNNCharEmbedding",
     "LSTM",
@@ -21,7 +27,8 @@ __all__ = [
 from .bert import BertModel
 from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
 from .conv_maxpool import ConvMaxpool
-from .embedding import Embedding
+from .embedding import Embedding, StaticEmbedding, ElmoEmbedding, BertEmbedding, \
+    StackEmbedding, LSTMCharEmbedding, CNNCharEmbedding
 from .lstm import LSTM
 from .star_transformer import StarTransformer
 from .transformer import TransformerEncoder
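
With the re-export above, the embedding classes resolve from the encoder package as well as from the embedding module; a quick hypothetical check of the new import surface (names taken from the __all__ list in this diff):

    from fastNLP.modules.encoder import StaticEmbedding, ElmoEmbedding, BertEmbedding
    from fastNLP.modules.encoder import StackEmbedding, LSTMCharEmbedding, CNNCharEmbedding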


fastNLP/modules/encoder/_bert.py (+1, -1)

@@ -9,7 +9,7 @@
 import torch
 from torch import nn

-from ... import Vocabulary
+from ...core.vocabulary import Vocabulary
 import collections

 import os


fastNLP/modules/encoder/embedding.py (+11, -3)

@@ -1,10 +1,16 @@
 __all__ = [
-    "Embedding"
+    "Embedding",
+    "StaticEmbedding",
+    "ElmoEmbedding",
+    "BertEmbedding",
+    "StackEmbedding",
+    "LSTMCharEmbedding",
+    "CNNCharEmbedding",
 ]
 import torch.nn as nn
 from ..utils import get_embeddings
 from .lstm import LSTM
-from ... import Vocabulary
+from ...core.vocabulary import Vocabulary
 from abc import abstractmethod
 import torch
 from ...io import EmbedLoader
@@ -15,7 +21,9 @@ from ...io.file_utils import cached_path, _get_base_url
 from ._bert import _WordBertModel
 from typing import List

-from ... import DataSet, DataSetIter, SequentialSampler
+from ...core.dataset import DataSet
+from ...core.batch import DataSetIter
+from ...core.sampler import SequentialSampler
 from ...core.utils import _move_model_to_device, _get_model_device
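
For context, MatchingLoader.process() above constructs a StaticEmbedding over its vocabulary; a minimal sketch of the same call, treating the 'en' alias and the vector-download behaviour as assumptions:

    from fastNLP.core.vocabulary import Vocabulary
    from fastNLP.modules.encoder.embedding import StaticEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst(['a', 'man', 'is', 'playing', 'a', 'guitar'])
    vocab.build_vocab()
    embed = StaticEmbedding(vocab, model_dir_or_name='en')  # pretrained English word vectors over vocab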






reproduction/matching/matching.py (+44, -0)

@@ -0,0 +1,44 @@
import os

import torch

from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric

from fastNLP.io.dataset_loader import MatchingLoader

from reproduction.matching.model.bert import BertForNLI


# bert_dirs = 'path/to/bert/dir'
bert_dirs = '/remote-home/ygxu/BERT/BERT_English_uncased_L-12_H-768_A_12'

# load data set
data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process(
    {  # 'train': './data/snli/snli_1.0_train.jsonl',
     'dev': './data/snli/snli_1.0_dev.jsonl',
     'test': './data/snli/snli_1.0_test.jsonl'}
)

print('successfully loaded data sets!')


model = BertForNLI(bert_dir=bert_dirs)

trainer = Trainer(train_data=data_info.datasets['dev'], model=model,
                  optimizer=Adam(lr=2e-5, model_params=model.parameters()),
                  batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1,
                  dev_data=data_info.datasets['dev'],
                  metrics=AccuracyMetric(), metric_key='acc',
                  device=[i for i in range(torch.cuda.device_count())],
                  check_code_level=-1)
trainer.train(load_best_model=True)

tester = Tester(
    data=data_info.datasets['test'],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=torch.cuda.device_count() * 12,
    device=[i for i in range(torch.cuda.device_count())],
)
tester.test()



reproduction/matching/snli.py (+0, -88)

@@ -1,88 +0,0 @@
import os

import torch

from fastNLP.core import Vocabulary, DataSet, Trainer, Tester, Const, Adam, AccuracyMetric

from reproduction.matching.data.SNLIDataLoader import SNLILoader
from legacy.component.bert_tokenizer import BertTokenizer
from reproduction.matching.model.bert import BertForNLI


def preprocess_data(data: DataSet, bert_dir):
    """
    preprocess data set to bert-need data set.
    :param data:
    :param bert_dir:
    :return:
    """
    tokenizer = BertTokenizer.from_pretrained(os.path.join(bert_dir, 'vocab.txt'))

    vocab = Vocabulary(padding=None, unknown=None)
    with open(os.path.join(bert_dir, 'vocab.txt')) as f:
        lines = f.readlines()
    vocab_list = []
    for line in lines:
        vocab_list.append(line.strip())
    vocab.add_word_lst(vocab_list)
    vocab.build_vocab()
    vocab.padding = '[PAD]'
    vocab.unknown = '[UNK]'

    for i in range(2):
        data.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])),
                   new_field_name=Const.INPUTS(i))
    data.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'],
               new_field_name=Const.INPUT)
    data.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
               new_field_name=Const.INPUT_LENS(0))
    data.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1))

    max_len = 512
    data.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT)
    data.apply(lambda x: [vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT)
    data.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0))
    data.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1))

    target_vocab = Vocabulary(padding=None, unknown=None)
    target_vocab.add_word_lst(['neutral', 'contradiction', 'entailment'])
    target_vocab.build_vocab()
    data.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET)

    data.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET)
    data.set_target(Const.TARGET)

    return data


bert_dirs = 'path/to/bert/dir'

# load raw data set
train_data = SNLILoader().load('./data/snli/snli_1.0_train.jsonl')
dev_data = SNLILoader().load('./data/snli/snli_1.0_dev.jsonl')
test_data = SNLILoader().load('./data/snli/snli_1.0_test.jsonl')

print('successfully load data sets!')

train_data = preprocess_data(train_data, bert_dirs)
dev_data = preprocess_data(dev_data, bert_dirs)
test_data = preprocess_data(test_data, bert_dirs)

model = BertForNLI(bert_dir=bert_dirs)

trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(lr=2e-5, model_params=model.parameters()),
                  batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1, dev_data=dev_data,
                  metrics=AccuracyMetric(), metric_key='acc',
                  device=[i for i in range(torch.cuda.device_count())],
                  check_code_level=-1)
trainer.train(load_best_model=True)

tester = Tester(
    data=test_data,
    model=model,
    metrics=AccuracyMetric(),
    batch_size=torch.cuda.device_count() * 12,
    device=[i for i in range(torch.cuda.device_count())],
)
tester.test()


