diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py index c8d6a441..83425ff7 100644 --- a/fastNLP/io/__init__.py +++ b/fastNLP/io/__init__.py @@ -16,6 +16,7 @@ __all__ = [ 'CSVLoader', 'JsonLoader', 'ConllLoader', + 'MatchingLoader', 'SNLILoader', 'SSTLoader', 'PeopleDailyCorpusLoader', @@ -26,6 +27,6 @@ __all__ = [ ] from .embed_loader import EmbedLoader -from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \ - PeopleDailyCorpusLoader, Conll2003Loader +from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, MatchingLoader,\ + SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader from .model_io import ModelLoader, ModelSaver diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index e366c6ea..0595ad46 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -16,19 +16,24 @@ __all__ = [ 'CSVLoader', 'JsonLoader', 'ConllLoader', + 'MatchingLoader', 'SNLILoader', 'SSTLoader', 'PeopleDailyCorpusLoader', 'Conll2003Loader', ] +import os from nltk import Tree +from typing import Union, Dict +from ..core.vocabulary import Vocabulary from ..core.dataset import DataSet from ..core.instance import Instance from .file_reader import _read_csv, _read_json, _read_conll -from .base_loader import DataSetLoader +from .base_loader import DataSetLoader, DataInfo from .data_loader.sst import SSTLoader from ..core.const import Const +from ..modules.encoder._bert import BertTokenizer class PeopleDailyCorpusLoader(DataSetLoader): @@ -244,6 +249,162 @@ class JsonLoader(DataSetLoader): return ds +class MatchingLoader(DataSetLoader): + """ + 别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` + + 读取Matching数据集,根据数据集做预处理并返回DataInfo。 + + 数据来源: + SNLI: https://nlp.stanford.edu/projects/snli/snli_1.0.zip + """ + + def __init__(self, data_format: str='snli', for_model: str='esim', bert_dir=None): + super(MatchingLoader, self).__init__() + self.data_format = data_format.lower() + self.for_model = for_model.lower() + self.bert_dir = bert_dir + + def _load(self, path: str) -> DataSet: + raise NotImplementedError + + def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo: + if isinstance(paths, str): + paths = {'train': paths} + + data_set = {} + for n, p in paths.items(): + if self.data_format == 'snli': + data = self._load_snli(p) + else: + raise RuntimeError(f'Your data format is {self.data_format}, ' + f'Please choose data format from [snli]') + + if self.for_model == 'esim': + data = self._for_esim(data) + elif self.for_model == 'bert': + data = self._for_bert(data, self.bert_dir) + else: + raise RuntimeError(f'Your model is {self.data_format}, ' + f'Please choose from [esim, bert]') + + data_set[n] = data + print(f'successfully load {n} set!') + + if not hasattr(self, 'vocab'): + raise RuntimeError(f'There is NOT vocab attribute built!') + if not hasattr(self, 'label_vocab'): + raise RuntimeError(f'There is NOT label vocab attribute built!') + + if self.for_model != 'bert': + from fastNLP.modules.encoder.embedding import StaticEmbedding + embedding = StaticEmbedding(self.vocab, model_dir_or_name='en') + + data_info = DataInfo(vocabs={'vocab': self.vocab, 'target_vocab': self.label_vocab}, + embeddings={'glove': embedding} if self.for_model != 'bert' else None, + datasets=data_set) + + return data_info + + @staticmethod + def _load_snli(path: str) -> DataSet: + """ + 读取SNLI数据集 + + 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip + :param str path: 数据集路径 + :return: + """ + raw_ds = JsonLoader( + fields={ + 'sentence1_parse': Const.INPUTS(0), + 'sentence2_parse': Const.INPUTS(1), + 'gold_label': Const.TARGET, + } + )._load(path) + return raw_ds + + def _for_esim(self, raw_ds: DataSet): + if self.data_format == 'snli' or self.data_format == 'mnli': + def parse_tree(x): + t = Tree.fromstring(x) + return t.leaves() + + raw_ds.apply(lambda ins: parse_tree( + ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) + raw_ds.apply(lambda ins: parse_tree( + ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) + raw_ds.drop(lambda x: x[Const.TARGET] == '-') + + if not hasattr(self, 'vocab'): + self.vocab = Vocabulary().from_dataset(raw_ds, [Const.INPUTS(0), Const.INPUTS(1)]) + if not hasattr(self, 'label_vocab'): + self.label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(raw_ds, field_name=Const.TARGET) + + raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0)) + raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1)) + raw_ds.apply(lambda ins: self.label_vocab.to_index(Const.TARGET), new_field_name=Const.TARGET) + + raw_ds.set_input(Const.INPUTS(0), Const.INPUTS(1)) + raw_ds.set_target(Const.TARGET) + + return raw_ds + + def _for_bert(self, raw_ds: DataSet, bert_dir: str): + if self.data_format == 'snli' or self.data_format == 'mnli': + def parse_tree(x): + t = Tree.fromstring(x) + return t.leaves() + + raw_ds.apply(lambda ins: parse_tree( + ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) + raw_ds.apply(lambda ins: parse_tree( + ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) + raw_ds.drop(lambda x: x[Const.TARGET] == '-') + + tokenizer = BertTokenizer.from_pretrained(bert_dir) + + vocab = Vocabulary(padding=None, unknown=None) + with open(os.path.join(bert_dir, 'vocab.txt')) as f: + lines = f.readlines() + vocab_list = [] + for line in lines: + vocab_list.append(line.strip()) + vocab.add_word_lst(vocab_list) + vocab.build_vocab() + vocab.padding = '[PAD]' + vocab.unknown = '[UNK]' + + if not hasattr(self, 'vocab'): + self.vocab = vocab + else: + for w, idx in self.vocab: + if vocab[w] != idx: + raise AttributeError(f"{self.__class__.__name__} has ") + + for i in range(2): + raw_ds.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])), new_field_name=Const.INPUTS(i)) + raw_ds.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'], + new_field_name=Const.INPUT) + raw_ds.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1), + new_field_name=Const.INPUT_LENS(0)) + raw_ds.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1)) + + max_len = 512 + raw_ds.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT) + raw_ds.apply(lambda x: [self.vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT) + raw_ds.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0)) + raw_ds.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1)) + + if not hasattr(self, 'label_vocab'): + self.label_vocab = Vocabulary(padding=None, unknown=None) + self.label_vocab.from_dataset(raw_ds, field_name=Const.TARGET) + raw_ds.apply(lambda x: self.label_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET) + + raw_ds.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1)) + raw_ds.set_target(Const.TARGET) + + class SNLILoader(JsonLoader): """ 别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` diff --git a/fastNLP/modules/encoder/__init__.py b/fastNLP/modules/encoder/__init__.py index bdc4cbf3..4be75f20 100644 --- a/fastNLP/modules/encoder/__init__.py +++ b/fastNLP/modules/encoder/__init__.py @@ -7,6 +7,12 @@ __all__ = [ "ConvMaxpool", "Embedding", + "StaticEmbedding", + "ElmoEmbedding", + "BertEmbedding", + "StackEmbedding", + "LSTMCharEmbedding", + "CNNCharEmbedding", "LSTM", @@ -21,7 +27,8 @@ __all__ = [ from .bert import BertModel from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder from .conv_maxpool import ConvMaxpool -from .embedding import Embedding +from .embedding import Embedding, StaticEmbedding, ElmoEmbedding, BertEmbedding, \ + StackEmbedding, LSTMCharEmbedding, CNNCharEmbedding from .lstm import LSTM from .star_transformer import StarTransformer from .transformer import TransformerEncoder diff --git a/fastNLP/modules/encoder/_bert.py b/fastNLP/modules/encoder/_bert.py index fc62ea9c..1423f333 100644 --- a/fastNLP/modules/encoder/_bert.py +++ b/fastNLP/modules/encoder/_bert.py @@ -9,7 +9,7 @@ import torch from torch import nn -from ... import Vocabulary +from ...core.vocabulary import Vocabulary import collections import os diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py index 7279a372..f956aae7 100644 --- a/fastNLP/modules/encoder/embedding.py +++ b/fastNLP/modules/encoder/embedding.py @@ -1,10 +1,16 @@ __all__ = [ - "Embedding" + "Embedding", + "StaticEmbedding", + "ElmoEmbedding", + "BertEmbedding", + "StackEmbedding", + "LSTMCharEmbedding", + "CNNCharEmbedding", ] import torch.nn as nn from ..utils import get_embeddings from .lstm import LSTM -from ... import Vocabulary +from ...core.vocabulary import Vocabulary from abc import abstractmethod import torch from ...io import EmbedLoader @@ -15,7 +21,9 @@ from ...io.file_utils import cached_path, _get_base_url from ._bert import _WordBertModel from typing import List -from ... import DataSet, DataSetIter, SequentialSampler +from ...core.dataset import DataSet +from ...core.batch import DataSetIter +from ...core.sampler import SequentialSampler from ...core.utils import _move_model_to_device, _get_model_device diff --git a/reproduction/matching/matching.py b/reproduction/matching/matching.py new file mode 100644 index 00000000..52c1c3b5 --- /dev/null +++ b/reproduction/matching/matching.py @@ -0,0 +1,44 @@ +import os + +import torch + +from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric + +from fastNLP.io.dataset_loader import MatchingLoader + +from reproduction.matching.model.bert import BertForNLI + + +# bert_dirs = 'path/to/bert/dir' +bert_dirs = '/remote-home/ygxu/BERT/BERT_English_uncased_L-12_H-768_A_12' + +# load data set +data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process( + {#'train': './data/snli/snli_1.0_train.jsonl', + 'dev': './data/snli/snli_1.0_dev.jsonl', + 'test': './data/snli/snli_1.0_test.jsonl'} +) + +print('successfully load data sets!') + + +model = BertForNLI(bert_dir=bert_dirs) + +trainer = Trainer(train_data=data_info.datasets['dev'], model=model, + optimizer=Adam(lr=2e-5, model_params=model.parameters()), + batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1, + dev_data=data_info.datasets['dev'], + metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], + check_code_level=-1) +trainer.train(load_best_model=True) + +tester = Tester( + data=data_info.datasets['test'], + model=model, + metrics=AccuracyMetric(), + batch_size=torch.cuda.device_count() * 12, + device=[i for i in range(torch.cuda.device_count())], +) +tester.test() + + diff --git a/reproduction/matching/snli.py b/reproduction/matching/snli.py deleted file mode 100644 index d7f392bd..00000000 --- a/reproduction/matching/snli.py +++ /dev/null @@ -1,88 +0,0 @@ -import os - -import torch - -from fastNLP.core import Vocabulary, DataSet, Trainer, Tester, Const, Adam, AccuracyMetric - -from reproduction.matching.data.SNLIDataLoader import SNLILoader -from legacy.component.bert_tokenizer import BertTokenizer -from reproduction.matching.model.bert import BertForNLI - - -def preprocess_data(data: DataSet, bert_dir): - """ - preprocess data set to bert-need data set. - :param data: - :param bert_dir: - :return: - """ - tokenizer = BertTokenizer.from_pretrained(os.path.join(bert_dir, 'vocab.txt')) - - vocab = Vocabulary(padding=None, unknown=None) - with open(os.path.join(bert_dir, 'vocab.txt')) as f: - lines = f.readlines() - vocab_list = [] - for line in lines: - vocab_list.append(line.strip()) - vocab.add_word_lst(vocab_list) - vocab.build_vocab() - vocab.padding = '[PAD]' - vocab.unknown = '[UNK]' - - for i in range(2): - data.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])), - new_field_name=Const.INPUTS(i)) - data.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'], - new_field_name=Const.INPUT) - data.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1), - new_field_name=Const.INPUT_LENS(0)) - data.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1)) - - max_len = 512 - data.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT) - data.apply(lambda x: [vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT) - data.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0)) - data.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1)) - - target_vocab = Vocabulary(padding=None, unknown=None) - target_vocab.add_word_lst(['neutral', 'contradiction', 'entailment']) - target_vocab.build_vocab() - data.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET) - - data.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET) - data.set_target(Const.TARGET) - - return data - - -bert_dirs = 'path/to/bert/dir' - -# load raw data set -train_data = SNLILoader().load('./data/snli/snli_1.0_train.jsonl') -dev_data = SNLILoader().load('./data/snli/snli_1.0_dev.jsonl') -test_data = SNLILoader().load('./data/snli/snli_1.0_test.jsonl') - -print('successfully load data sets!') - -train_data = preprocess_data(train_data, bert_dirs) -dev_data = preprocess_data(dev_data, bert_dirs) -test_data = preprocess_data(test_data, bert_dirs) - -model = BertForNLI(bert_dir=bert_dirs) - -trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(lr=2e-5, model_params=model.parameters()), - batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1, dev_data=dev_data, - metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], - check_code_level=-1) -trainer.train(load_best_model=True) - -tester = Tester( - data=test_data, - model=model, - metrics=AccuracyMetric(), - batch_size=torch.cuda.device_count() * 12, - device=[i for i in range(torch.cuda.device_count())], -) -tester.test() - -