@@ -16,6 +16,7 @@ __all__ = [
     'CSVLoader',
     'JsonLoader',
     'ConllLoader',
+    'MatchingLoader',
     'SNLILoader',
    'SSTLoader',
     'PeopleDailyCorpusLoader',
@@ -26,6 +27,6 @@ __all__ = [
 ]

 from .embed_loader import EmbedLoader
-from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \
-    PeopleDailyCorpusLoader, Conll2003Loader
+from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, MatchingLoader, \
+    SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader
 from .model_io import ModelLoader, ModelSaver
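
Note: with the re-export above, the new loader is reachable from the package root. A minimal sketch (constructor arguments follow the class added in the next file):

    from fastNLP.io import MatchingLoader

    loader = MatchingLoader(data_format='snli', for_model='esim')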
@@ -16,19 +16,24 @@ __all__ = [
     'CSVLoader',
     'JsonLoader',
     'ConllLoader',
+    'MatchingLoader',
     'SNLILoader',
     'SSTLoader',
     'PeopleDailyCorpusLoader',
     'Conll2003Loader',
 ]

+import os
 from nltk import Tree
+from typing import Union, Dict
+from ..core.vocabulary import Vocabulary
 from ..core.dataset import DataSet
 from ..core.instance import Instance
 from .file_reader import _read_csv, _read_json, _read_conll
-from .base_loader import DataSetLoader
+from .base_loader import DataSetLoader, DataInfo
 from .data_loader.sst import SSTLoader
 from ..core.const import Const
+from ..modules.encoder._bert import BertTokenizer


 class PeopleDailyCorpusLoader(DataSetLoader):
@@ -244,6 +249,162 @@ class JsonLoader(DataSetLoader):
         return ds


+class MatchingLoader(DataSetLoader):
+    """
+    Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`
+
+    Reads a matching (sentence-pair) dataset, preprocesses it for the target model and returns a DataInfo.
+
+    Data source:
+        SNLI: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
+    """
+
+    def __init__(self, data_format: str='snli', for_model: str='esim', bert_dir=None):
+        super(MatchingLoader, self).__init__()
+        self.data_format = data_format.lower()
+        self.for_model = for_model.lower()
+        self.bert_dir = bert_dir
+
+    def _load(self, path: str) -> DataSet:
+        raise NotImplementedError
+
+    def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
+        if isinstance(paths, str):
+            paths = {'train': paths}
+
+        data_set = {}
+        for n, p in paths.items():
+            if self.data_format == 'snli':
+                data = self._load_snli(p)
+            else:
+                raise RuntimeError(f'Your data format is {self.data_format}, '
+                                   f'please choose a data format from [snli].')
+
+            if self.for_model == 'esim':
+                data = self._for_esim(data)
+            elif self.for_model == 'bert':
+                data = self._for_bert(data, self.bert_dir)
+            else:
+                raise RuntimeError(f'Your model is {self.for_model}, '
+                                   f'please choose a model from [esim, bert].')
+            data_set[n] = data
+            print(f'successfully loaded {n} set!')
+
+        if not hasattr(self, 'vocab'):
+            raise RuntimeError('No vocab attribute has been built!')
+        if not hasattr(self, 'label_vocab'):
+            raise RuntimeError('No label_vocab attribute has been built!')
+
+        if self.for_model != 'bert':
+            from fastNLP.modules.encoder.embedding import StaticEmbedding
+            embedding = StaticEmbedding(self.vocab, model_dir_or_name='en')
+
+        data_info = DataInfo(vocabs={'vocab': self.vocab, 'target_vocab': self.label_vocab},
+                             embeddings={'glove': embedding} if self.for_model != 'bert' else None,
+                             datasets=data_set)
+
+        return data_info
+
+    @staticmethod
+    def _load_snli(path: str) -> DataSet:
+        """
+        Read the SNLI dataset.
+
+        Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
+
+        :param str path: path to the dataset file
+        :return: a raw DataSet
+        """
+        raw_ds = JsonLoader(
+            fields={
+                'sentence1_parse': Const.INPUTS(0),
+                'sentence2_parse': Const.INPUTS(1),
+                'gold_label': Const.TARGET,
+            }
+        )._load(path)
+        return raw_ds
+
+    def _for_esim(self, raw_ds: DataSet):
+        if self.data_format == 'snli' or self.data_format == 'mnli':
+            # recover the token sequences from the constituency-parse strings
+            def parse_tree(x):
+                t = Tree.fromstring(x)
+                return t.leaves()
+
+            raw_ds.apply(lambda ins: parse_tree(
+                ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
+            raw_ds.apply(lambda ins: parse_tree(
+                ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
+            raw_ds.drop(lambda x: x[Const.TARGET] == '-')
+
+        if not hasattr(self, 'vocab'):
+            self.vocab = Vocabulary().from_dataset(raw_ds, [Const.INPUTS(0), Const.INPUTS(1)])
+        if not hasattr(self, 'label_vocab'):
+            self.label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(raw_ds, field_name=Const.TARGET)
+
+        raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0))
+        raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1))
+        raw_ds.apply(lambda ins: self.label_vocab.to_index(ins[Const.TARGET]), new_field_name=Const.TARGET)
+
+        raw_ds.set_input(Const.INPUTS(0), Const.INPUTS(1))
+        raw_ds.set_target(Const.TARGET)
+
+        return raw_ds
+
+    def _for_bert(self, raw_ds: DataSet, bert_dir: str):
+        if self.data_format == 'snli' or self.data_format == 'mnli':
+            def parse_tree(x):
+                t = Tree.fromstring(x)
+                return t.leaves()
+
+            raw_ds.apply(lambda ins: parse_tree(
+                ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
+            raw_ds.apply(lambda ins: parse_tree(
+                ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
+            raw_ds.drop(lambda x: x[Const.TARGET] == '-')
+
+        tokenizer = BertTokenizer.from_pretrained(bert_dir)
+
+        # mirror BERT's vocab.txt so word indices match the pretrained checkpoint
+        vocab = Vocabulary(padding=None, unknown=None)
+        with open(os.path.join(bert_dir, 'vocab.txt')) as f:
+            lines = f.readlines()
+        vocab_list = []
+        for line in lines:
+            vocab_list.append(line.strip())
+        vocab.add_word_lst(vocab_list)
+        vocab.build_vocab()
+        vocab.padding = '[PAD]'
+        vocab.unknown = '[UNK]'
+
+        if not hasattr(self, 'vocab'):
+            self.vocab = vocab
+        else:
+            for w, idx in self.vocab:
+                if vocab[w] != idx:
+                    raise AttributeError(f'{self.__class__.__name__} already has a vocab that is '
+                                         f'inconsistent with the BERT vocab in {bert_dir}')
+
+        for i in range(2):
+            raw_ds.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])), new_field_name=Const.INPUTS(i))
+        raw_ds.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'],
+                     new_field_name=Const.INPUT)
+        raw_ds.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
+                     new_field_name=Const.INPUT_LENS(0))
+        raw_ds.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1))
+
+        max_len = 512
+        raw_ds.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT)
+        raw_ds.apply(lambda x: [self.vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT)
+        raw_ds.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0))
+        raw_ds.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1))
+
+        if not hasattr(self, 'label_vocab'):
+            self.label_vocab = Vocabulary(padding=None, unknown=None)
+            self.label_vocab.from_dataset(raw_ds, field_name=Const.TARGET)
+        raw_ds.apply(lambda x: self.label_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET)
+
+        raw_ds.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1))
+        raw_ds.set_target(Const.TARGET)
+
+        # process() expects the preprocessed DataSet back
+        return raw_ds
+
+
 class SNLILoader(JsonLoader):
     """
     Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`
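
Usage sketch for the new loader, assuming the official SNLI jsonl files sit under ./data/snli/ (paths are placeholders); the datasets/vocabs/embeddings keys follow the DataInfo built in process() above:

    from fastNLP.io.dataset_loader import MatchingLoader

    # a single path string is treated as the 'train' split; a dict maps split names to paths
    data_info = MatchingLoader(data_format='snli', for_model='esim').process(
        {'train': './data/snli/snli_1.0_train.jsonl',
         'dev': './data/snli/snli_1.0_dev.jsonl'}
    )
    train_set = data_info.datasets['train']    # preprocessed DataSet, inputs/targets already set
    word_vocab = data_info.vocabs['vocab']     # shared vocabulary over both sentences
    glove = data_info.embeddings['glove']      # StaticEmbedding built over word_vocab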
@@ -7,6 +7,12 @@ __all__ = [
     "ConvMaxpool",

     "Embedding",
+    "StaticEmbedding",
+    "ElmoEmbedding",
+    "BertEmbedding",
+    "StackEmbedding",
+    "LSTMCharEmbedding",
+    "CNNCharEmbedding",

     "LSTM",
@@ -21,7 +27,8 @@ __all__ = [
 from .bert import BertModel
 from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
 from .conv_maxpool import ConvMaxpool
-from .embedding import Embedding
+from .embedding import Embedding, StaticEmbedding, ElmoEmbedding, BertEmbedding, \
+    StackEmbedding, LSTMCharEmbedding, CNNCharEmbedding
 from .lstm import LSTM
 from .star_transformer import StarTransformer
 from .transformer import TransformerEncoder
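
With the widened export list, the embedding classes can be imported from the encoder package directly, e.g.:

    from fastNLP.modules.encoder import StaticEmbedding, BertEmbedding, CNNCharEmbedding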
@@ -9,7 +9,7 @@
 import torch
 from torch import nn

-from ... import Vocabulary
+from ...core.vocabulary import Vocabulary
 import collections
 import os
@@ -1,10 +1,16 @@
 __all__ = [
-    "Embedding"
+    "Embedding",
+    "StaticEmbedding",
+    "ElmoEmbedding",
+    "BertEmbedding",
+    "StackEmbedding",
+    "LSTMCharEmbedding",
+    "CNNCharEmbedding",
 ]
 import torch.nn as nn
 from ..utils import get_embeddings
 from .lstm import LSTM
-from ... import Vocabulary
+from ...core.vocabulary import Vocabulary
 from abc import abstractmethod
 import torch
 from ...io import EmbedLoader
@@ -15,7 +21,9 @@ from ...io.file_utils import cached_path, _get_base_url
 from ._bert import _WordBertModel
 from typing import List

-from ... import DataSet, DataSetIter, SequentialSampler
+from ...core.dataset import DataSet
+from ...core.batch import DataSetIter
+from ...core.sampler import SequentialSampler
 from ...core.utils import _move_model_to_device, _get_model_device
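
Note: this and the two hunks above replace imports from the top-level package (from ... import X) with imports from the defining submodules, presumably so these encoder modules no longer depend on fastNLP/__init__.py having finished executing — a common source of circular-import errors. Behaviour is otherwise unchanged.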
@@ -0,0 +1,44 @@
+import torch
+
+from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric
+from fastNLP.io.dataset_loader import MatchingLoader
+
+from reproduction.matching.model.bert import BertForNLI
+
+
+# path to an uncased BERT-base directory containing vocab.txt and the model weights
+bert_dirs = 'path/to/bert/dir'
+
+# load and preprocess the three SNLI splits
+data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process(
+    {'train': './data/snli/snli_1.0_train.jsonl',
+     'dev': './data/snli/snli_1.0_dev.jsonl',
+     'test': './data/snli/snli_1.0_test.jsonl'}
+)
+
+print('successfully loaded data sets!')
+
+model = BertForNLI(bert_dir=bert_dirs)
+
+trainer = Trainer(train_data=data_info.datasets['train'], model=model,
+                  optimizer=Adam(lr=2e-5, model_params=model.parameters()),
+                  batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1,
+                  dev_data=data_info.datasets['dev'],
+                  metrics=AccuracyMetric(), metric_key='acc',
+                  device=[i for i in range(torch.cuda.device_count())],
+                  check_code_level=-1)
+trainer.train(load_best_model=True)
+
+tester = Tester(
+    data=data_info.datasets['test'],
+    model=model,
+    metrics=AccuracyMetric(),
+    batch_size=torch.cuda.device_count() * 12,
+    device=[i for i in range(torch.cuda.device_count())],
+)
+tester.test()
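
One caveat, not covered by this patch: batch_size=torch.cuda.device_count() * 12 evaluates to 0 on a CPU-only machine. A guard could look like the sketch below (assumption: Trainer treats device=None as "leave the model where it is"):

    n_gpu = torch.cuda.device_count()
    batch_size = max(1, n_gpu) * 12                      # fall back to 12 on CPU
    device = list(range(n_gpu)) if n_gpu > 0 else None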
@@ -1,88 +0,0 @@
-import os
-
-import torch
-
-from fastNLP.core import Vocabulary, DataSet, Trainer, Tester, Const, Adam, AccuracyMetric
-
-from reproduction.matching.data.SNLIDataLoader import SNLILoader
-from legacy.component.bert_tokenizer import BertTokenizer
-from reproduction.matching.model.bert import BertForNLI
-
-
-def preprocess_data(data: DataSet, bert_dir):
-    """
-    preprocess data set to bert-need data set.
-    :param data:
-    :param bert_dir:
-    :return:
-    """
-    tokenizer = BertTokenizer.from_pretrained(os.path.join(bert_dir, 'vocab.txt'))
-
-    vocab = Vocabulary(padding=None, unknown=None)
-    with open(os.path.join(bert_dir, 'vocab.txt')) as f:
-        lines = f.readlines()
-    vocab_list = []
-    for line in lines:
-        vocab_list.append(line.strip())
-    vocab.add_word_lst(vocab_list)
-    vocab.build_vocab()
-    vocab.padding = '[PAD]'
-    vocab.unknown = '[UNK]'
-
-    for i in range(2):
-        data.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])),
-                   new_field_name=Const.INPUTS(i))
-    data.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'],
-               new_field_name=Const.INPUT)
-    data.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
-               new_field_name=Const.INPUT_LENS(0))
-    data.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1))
-
-    max_len = 512
-    data.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT)
-    data.apply(lambda x: [vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT)
-    data.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0))
-    data.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1))
-
-    target_vocab = Vocabulary(padding=None, unknown=None)
-    target_vocab.add_word_lst(['neutral', 'contradiction', 'entailment'])
-    target_vocab.build_vocab()
-    data.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET)
-
-    data.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET)
-    data.set_target(Const.TARGET)
-
-    return data
-
-
-bert_dirs = 'path/to/bert/dir'
-
-# load raw data set
-train_data = SNLILoader().load('./data/snli/snli_1.0_train.jsonl')
-dev_data = SNLILoader().load('./data/snli/snli_1.0_dev.jsonl')
-test_data = SNLILoader().load('./data/snli/snli_1.0_test.jsonl')
-
-print('successfully load data sets!')
-
-train_data = preprocess_data(train_data, bert_dirs)
-dev_data = preprocess_data(dev_data, bert_dirs)
-test_data = preprocess_data(test_data, bert_dirs)
-
-model = BertForNLI(bert_dir=bert_dirs)
-
-trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(lr=2e-5, model_params=model.parameters()),
-                  batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1, dev_data=dev_data,
-                  metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())],
-                  check_code_level=-1)
-trainer.train(load_best_model=True)
-
-tester = Tester(
-    data=test_data,
-    model=model,
-    metrics=AccuracyMetric(),
-    batch_size=torch.cuda.device_count() * 12,
-    device=[i for i in range(torch.cuda.device_count())],
-)
-tester.test()