From 39388567ad7e0fd39fa39a993e8ddeaa6e5f4ff7 Mon Sep 17 00:00:00 2001 From: xuyige Date: Mon, 17 Jun 2019 21:48:18 +0800 Subject: [PATCH 1/2] update matching.py --- fastNLP/io/__init__.py | 5 +- fastNLP/io/dataset_loader.py | 163 ++++++++++++++++++++++++++- fastNLP/modules/encoder/__init__.py | 9 +- fastNLP/modules/encoder/_bert.py | 2 +- fastNLP/modules/encoder/embedding.py | 14 ++- reproduction/matching/matching.py | 44 ++++++++ reproduction/matching/snli.py | 88 --------------- 7 files changed, 229 insertions(+), 96 deletions(-) create mode 100644 reproduction/matching/matching.py delete mode 100644 reproduction/matching/snli.py diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py index c8d6a441..83425ff7 100644 --- a/fastNLP/io/__init__.py +++ b/fastNLP/io/__init__.py @@ -16,6 +16,7 @@ __all__ = [ 'CSVLoader', 'JsonLoader', 'ConllLoader', + 'MatchingLoader', 'SNLILoader', 'SSTLoader', 'PeopleDailyCorpusLoader', @@ -26,6 +27,6 @@ __all__ = [ ] from .embed_loader import EmbedLoader -from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \ - PeopleDailyCorpusLoader, Conll2003Loader +from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, MatchingLoader,\ + SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader from .model_io import ModelLoader, ModelSaver diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index e366c6ea..0595ad46 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -16,19 +16,24 @@ __all__ = [ 'CSVLoader', 'JsonLoader', 'ConllLoader', + 'MatchingLoader', 'SNLILoader', 'SSTLoader', 'PeopleDailyCorpusLoader', 'Conll2003Loader', ] +import os from nltk import Tree +from typing import Union, Dict +from ..core.vocabulary import Vocabulary from ..core.dataset import DataSet from ..core.instance import Instance from .file_reader import _read_csv, _read_json, _read_conll -from .base_loader import DataSetLoader +from .base_loader import DataSetLoader, DataInfo from .data_loader.sst import SSTLoader from ..core.const import Const +from ..modules.encoder._bert import BertTokenizer class PeopleDailyCorpusLoader(DataSetLoader): @@ -244,6 +249,162 @@ class JsonLoader(DataSetLoader): return ds +class MatchingLoader(DataSetLoader): + """ + 别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` + + 读取Matching数据集,根据数据集做预处理并返回DataInfo。 + + 数据来源: + SNLI: https://nlp.stanford.edu/projects/snli/snli_1.0.zip + """ + + def __init__(self, data_format: str='snli', for_model: str='esim', bert_dir=None): + super(MatchingLoader, self).__init__() + self.data_format = data_format.lower() + self.for_model = for_model.lower() + self.bert_dir = bert_dir + + def _load(self, path: str) -> DataSet: + raise NotImplementedError + + def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo: + if isinstance(paths, str): + paths = {'train': paths} + + data_set = {} + for n, p in paths.items(): + if self.data_format == 'snli': + data = self._load_snli(p) + else: + raise RuntimeError(f'Your data format is {self.data_format}, ' + f'Please choose data format from [snli]') + + if self.for_model == 'esim': + data = self._for_esim(data) + elif self.for_model == 'bert': + data = self._for_bert(data, self.bert_dir) + else: + raise RuntimeError(f'Your model is {self.data_format}, ' + f'Please choose from [esim, bert]') + + data_set[n] = data + print(f'successfully load {n} set!') + + if not hasattr(self, 'vocab'): + raise RuntimeError(f'There is NOT vocab attribute built!') + if not hasattr(self, 'label_vocab'): + raise RuntimeError(f'There is NOT label vocab attribute built!') + + if self.for_model != 'bert': + from fastNLP.modules.encoder.embedding import StaticEmbedding + embedding = StaticEmbedding(self.vocab, model_dir_or_name='en') + + data_info = DataInfo(vocabs={'vocab': self.vocab, 'target_vocab': self.label_vocab}, + embeddings={'glove': embedding} if self.for_model != 'bert' else None, + datasets=data_set) + + return data_info + + @staticmethod + def _load_snli(path: str) -> DataSet: + """ + 读取SNLI数据集 + + 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip + :param str path: 数据集路径 + :return: + """ + raw_ds = JsonLoader( + fields={ + 'sentence1_parse': Const.INPUTS(0), + 'sentence2_parse': Const.INPUTS(1), + 'gold_label': Const.TARGET, + } + )._load(path) + return raw_ds + + def _for_esim(self, raw_ds: DataSet): + if self.data_format == 'snli' or self.data_format == 'mnli': + def parse_tree(x): + t = Tree.fromstring(x) + return t.leaves() + + raw_ds.apply(lambda ins: parse_tree( + ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) + raw_ds.apply(lambda ins: parse_tree( + ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) + raw_ds.drop(lambda x: x[Const.TARGET] == '-') + + if not hasattr(self, 'vocab'): + self.vocab = Vocabulary().from_dataset(raw_ds, [Const.INPUTS(0), Const.INPUTS(1)]) + if not hasattr(self, 'label_vocab'): + self.label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(raw_ds, field_name=Const.TARGET) + + raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0)) + raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1)) + raw_ds.apply(lambda ins: self.label_vocab.to_index(Const.TARGET), new_field_name=Const.TARGET) + + raw_ds.set_input(Const.INPUTS(0), Const.INPUTS(1)) + raw_ds.set_target(Const.TARGET) + + return raw_ds + + def _for_bert(self, raw_ds: DataSet, bert_dir: str): + if self.data_format == 'snli' or self.data_format == 'mnli': + def parse_tree(x): + t = Tree.fromstring(x) + return t.leaves() + + raw_ds.apply(lambda ins: parse_tree( + ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) + raw_ds.apply(lambda ins: parse_tree( + ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) + raw_ds.drop(lambda x: x[Const.TARGET] == '-') + + tokenizer = BertTokenizer.from_pretrained(bert_dir) + + vocab = Vocabulary(padding=None, unknown=None) + with open(os.path.join(bert_dir, 'vocab.txt')) as f: + lines = f.readlines() + vocab_list = [] + for line in lines: + vocab_list.append(line.strip()) + vocab.add_word_lst(vocab_list) + vocab.build_vocab() + vocab.padding = '[PAD]' + vocab.unknown = '[UNK]' + + if not hasattr(self, 'vocab'): + self.vocab = vocab + else: + for w, idx in self.vocab: + if vocab[w] != idx: + raise AttributeError(f"{self.__class__.__name__} has ") + + for i in range(2): + raw_ds.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])), new_field_name=Const.INPUTS(i)) + raw_ds.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'], + new_field_name=Const.INPUT) + raw_ds.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1), + new_field_name=Const.INPUT_LENS(0)) + raw_ds.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1)) + + max_len = 512 + raw_ds.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT) + raw_ds.apply(lambda x: [self.vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT) + raw_ds.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0)) + raw_ds.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1)) + + if not hasattr(self, 'label_vocab'): + self.label_vocab = Vocabulary(padding=None, unknown=None) + self.label_vocab.from_dataset(raw_ds, field_name=Const.TARGET) + raw_ds.apply(lambda x: self.label_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET) + + raw_ds.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1)) + raw_ds.set_target(Const.TARGET) + + class SNLILoader(JsonLoader): """ 别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` diff --git a/fastNLP/modules/encoder/__init__.py b/fastNLP/modules/encoder/__init__.py index bdc4cbf3..4be75f20 100644 --- a/fastNLP/modules/encoder/__init__.py +++ b/fastNLP/modules/encoder/__init__.py @@ -7,6 +7,12 @@ __all__ = [ "ConvMaxpool", "Embedding", + "StaticEmbedding", + "ElmoEmbedding", + "BertEmbedding", + "StackEmbedding", + "LSTMCharEmbedding", + "CNNCharEmbedding", "LSTM", @@ -21,7 +27,8 @@ __all__ = [ from .bert import BertModel from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder from .conv_maxpool import ConvMaxpool -from .embedding import Embedding +from .embedding import Embedding, StaticEmbedding, ElmoEmbedding, BertEmbedding, \ + StackEmbedding, LSTMCharEmbedding, CNNCharEmbedding from .lstm import LSTM from .star_transformer import StarTransformer from .transformer import TransformerEncoder diff --git a/fastNLP/modules/encoder/_bert.py b/fastNLP/modules/encoder/_bert.py index fc62ea9c..1423f333 100644 --- a/fastNLP/modules/encoder/_bert.py +++ b/fastNLP/modules/encoder/_bert.py @@ -9,7 +9,7 @@ import torch from torch import nn -from ... import Vocabulary +from ...core.vocabulary import Vocabulary import collections import os diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py index 7279a372..f956aae7 100644 --- a/fastNLP/modules/encoder/embedding.py +++ b/fastNLP/modules/encoder/embedding.py @@ -1,10 +1,16 @@ __all__ = [ - "Embedding" + "Embedding", + "StaticEmbedding", + "ElmoEmbedding", + "BertEmbedding", + "StackEmbedding", + "LSTMCharEmbedding", + "CNNCharEmbedding", ] import torch.nn as nn from ..utils import get_embeddings from .lstm import LSTM -from ... import Vocabulary +from ...core.vocabulary import Vocabulary from abc import abstractmethod import torch from ...io import EmbedLoader @@ -15,7 +21,9 @@ from ...io.file_utils import cached_path, _get_base_url from ._bert import _WordBertModel from typing import List -from ... import DataSet, DataSetIter, SequentialSampler +from ...core.dataset import DataSet +from ...core.batch import DataSetIter +from ...core.sampler import SequentialSampler from ...core.utils import _move_model_to_device, _get_model_device diff --git a/reproduction/matching/matching.py b/reproduction/matching/matching.py new file mode 100644 index 00000000..52c1c3b5 --- /dev/null +++ b/reproduction/matching/matching.py @@ -0,0 +1,44 @@ +import os + +import torch + +from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric + +from fastNLP.io.dataset_loader import MatchingLoader + +from reproduction.matching.model.bert import BertForNLI + + +# bert_dirs = 'path/to/bert/dir' +bert_dirs = '/remote-home/ygxu/BERT/BERT_English_uncased_L-12_H-768_A_12' + +# load data set +data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process( + {#'train': './data/snli/snli_1.0_train.jsonl', + 'dev': './data/snli/snli_1.0_dev.jsonl', + 'test': './data/snli/snli_1.0_test.jsonl'} +) + +print('successfully load data sets!') + + +model = BertForNLI(bert_dir=bert_dirs) + +trainer = Trainer(train_data=data_info.datasets['dev'], model=model, + optimizer=Adam(lr=2e-5, model_params=model.parameters()), + batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1, + dev_data=data_info.datasets['dev'], + metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], + check_code_level=-1) +trainer.train(load_best_model=True) + +tester = Tester( + data=data_info.datasets['test'], + model=model, + metrics=AccuracyMetric(), + batch_size=torch.cuda.device_count() * 12, + device=[i for i in range(torch.cuda.device_count())], +) +tester.test() + + diff --git a/reproduction/matching/snli.py b/reproduction/matching/snli.py deleted file mode 100644 index d7f392bd..00000000 --- a/reproduction/matching/snli.py +++ /dev/null @@ -1,88 +0,0 @@ -import os - -import torch - -from fastNLP.core import Vocabulary, DataSet, Trainer, Tester, Const, Adam, AccuracyMetric - -from reproduction.matching.data.SNLIDataLoader import SNLILoader -from legacy.component.bert_tokenizer import BertTokenizer -from reproduction.matching.model.bert import BertForNLI - - -def preprocess_data(data: DataSet, bert_dir): - """ - preprocess data set to bert-need data set. - :param data: - :param bert_dir: - :return: - """ - tokenizer = BertTokenizer.from_pretrained(os.path.join(bert_dir, 'vocab.txt')) - - vocab = Vocabulary(padding=None, unknown=None) - with open(os.path.join(bert_dir, 'vocab.txt')) as f: - lines = f.readlines() - vocab_list = [] - for line in lines: - vocab_list.append(line.strip()) - vocab.add_word_lst(vocab_list) - vocab.build_vocab() - vocab.padding = '[PAD]' - vocab.unknown = '[UNK]' - - for i in range(2): - data.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])), - new_field_name=Const.INPUTS(i)) - data.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'], - new_field_name=Const.INPUT) - data.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1), - new_field_name=Const.INPUT_LENS(0)) - data.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1)) - - max_len = 512 - data.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT) - data.apply(lambda x: [vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT) - data.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0)) - data.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1)) - - target_vocab = Vocabulary(padding=None, unknown=None) - target_vocab.add_word_lst(['neutral', 'contradiction', 'entailment']) - target_vocab.build_vocab() - data.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET) - - data.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET) - data.set_target(Const.TARGET) - - return data - - -bert_dirs = 'path/to/bert/dir' - -# load raw data set -train_data = SNLILoader().load('./data/snli/snli_1.0_train.jsonl') -dev_data = SNLILoader().load('./data/snli/snli_1.0_dev.jsonl') -test_data = SNLILoader().load('./data/snli/snli_1.0_test.jsonl') - -print('successfully load data sets!') - -train_data = preprocess_data(train_data, bert_dirs) -dev_data = preprocess_data(dev_data, bert_dirs) -test_data = preprocess_data(test_data, bert_dirs) - -model = BertForNLI(bert_dir=bert_dirs) - -trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(lr=2e-5, model_params=model.parameters()), - batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1, dev_data=dev_data, - metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], - check_code_level=-1) -trainer.train(load_best_model=True) - -tester = Tester( - data=test_data, - model=model, - metrics=AccuracyMetric(), - batch_size=torch.cuda.device_count() * 12, - device=[i for i in range(torch.cuda.device_count())], -) -tester.test() - - From 93620e76edf0162f8b9d8f844728dfd1c203e58d Mon Sep 17 00:00:00 2001 From: xuyige Date: Tue, 18 Jun 2019 02:04:53 +0800 Subject: [PATCH 2/2] update framework of matching --- fastNLP/io/dataset_loader.py | 25 ++-- fastNLP/modules/encoder/bert.py | 4 +- fastNLP/modules/encoder/embedding.py | 11 +- reproduction/matching/matching.py | 26 ++-- reproduction/matching/model/esim.py | 182 +++++++++++++++++++++++++++ 5 files changed, 221 insertions(+), 27 deletions(-) create mode 100644 reproduction/matching/model/esim.py diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index c63ff2f4..b0bf2e60 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -269,7 +269,7 @@ class MatchingLoader(DataSetLoader): def _load(self, path: str) -> DataSet: raise NotImplementedError - def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo: + def process(self, paths: Union[str, Dict[str, str]], input_field=None) -> DataInfo: if isinstance(paths, str): paths = {'train': paths} @@ -289,6 +289,13 @@ class MatchingLoader(DataSetLoader): raise RuntimeError(f'Your model is {self.data_format}, ' f'Please choose from [esim, bert]') + if input_field is not None: + if isinstance(input_field, str): + data.set_input(input_field) + elif isinstance(input_field, list): + for field in input_field: + data.set_input(field) + data_set[n] = data print(f'successfully load {n} set!') @@ -298,11 +305,11 @@ class MatchingLoader(DataSetLoader): raise RuntimeError(f'There is NOT label vocab attribute built!') if self.for_model != 'bert': - from fastNLP.modules.encoder.embedding import StaticEmbedding - embedding = StaticEmbedding(self.vocab, model_dir_or_name='en') + from fastNLP.modules.encoder.embedding import ElmoEmbedding + embedding = ElmoEmbedding(self.vocab, model_dir_or_name='en', requires_grad=True, layers='2') data_info = DataInfo(vocabs={'vocab': self.vocab, 'target_vocab': self.label_vocab}, - embeddings={'glove': embedding} if self.for_model != 'bert' else None, + embeddings={'elmo': embedding} if self.for_model != 'bert' else None, datasets=data_set) return data_info @@ -338,15 +345,17 @@ class MatchingLoader(DataSetLoader): raw_ds.drop(lambda x: x[Const.TARGET] == '-') if not hasattr(self, 'vocab'): - self.vocab = Vocabulary().from_dataset(raw_ds, [Const.INPUTS(0), Const.INPUTS(1)]) + self.vocab = Vocabulary().from_dataset(raw_ds, field_name=[Const.INPUTS(0), Const.INPUTS(1)]) if not hasattr(self, 'label_vocab'): self.label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(raw_ds, field_name=Const.TARGET) raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0)) raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1)) - raw_ds.apply(lambda ins: self.label_vocab.to_index(Const.TARGET), new_field_name=Const.TARGET) + raw_ds.apply(lambda ins: self.label_vocab.to_index(ins[Const.TARGET]), new_field_name=Const.TARGET) + raw_ds.apply(lambda ins: len(ins[Const.INPUTS(0)]), new_field_name=Const.INPUT_LENS(0)) + raw_ds.apply(lambda ins: len(ins[Const.INPUTS(1)]), new_field_name=Const.INPUT_LENS(1)) - raw_ds.set_input(Const.INPUTS(0), Const.INPUTS(1)) + raw_ds.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1)) raw_ds.set_target(Const.TARGET) return raw_ds @@ -405,6 +414,8 @@ class MatchingLoader(DataSetLoader): raw_ds.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1)) raw_ds.set_target(Const.TARGET) + return raw_ds + class SNLILoader(JsonLoader): """ diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py index e9739c28..4948d022 100644 --- a/fastNLP/modules/encoder/bert.py +++ b/fastNLP/modules/encoder/bert.py @@ -2,9 +2,9 @@ import os from torch import nn import torch -from ...core import Vocabulary +from ...core.vocabulary import Vocabulary from ...io.file_utils import _get_base_url, cached_path -from ._bert import _WordPieceBertModel +from ._bert import _WordPieceBertModel, BertModel class BertWordPieceEncoder(nn.Module): diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py index 7fd85578..9c1bf35f 100644 --- a/fastNLP/modules/encoder/embedding.py +++ b/fastNLP/modules/encoder/embedding.py @@ -152,6 +152,8 @@ class StaticEmbedding(TokenEmbedding): Example:: + >>> embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50') + :param vocab: Vocabulary. 若该项为None则会读取所有的embedding。 :param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding的文件名,第二种是传入embedding @@ -311,8 +313,7 @@ class ElmoEmbedding(ContextualEmbedding): Example:: - >>> - >>> + >>> embedding = ElmoEmbedding(vocab, model_dir_or_name='en', layers='2', requires_grad=True) :param vocab: 词表 :param model_dir_or_name: 可以有两种方式调用预训练好的ELMo embedding:第一种是传入ELMo权重的文件名,第二种是传入ELMo版本的名称, @@ -403,7 +404,7 @@ class BertEmbedding(ContextualEmbedding): Example:: - >>> + >>> embedding = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', requires_grad=False, layers='4,-2,-1') :param fastNLP.Vocabulary vocab: 词表 @@ -513,7 +514,7 @@ class CNNCharEmbedding(TokenEmbedding): Example:: - >>> + >>> cnn_char_embed = CNNCharEmbedding(vocab) :param vocab: 词表 @@ -647,7 +648,7 @@ class LSTMCharEmbedding(TokenEmbedding): Example:: - >>> + >>> lstm_char_embed = LSTMCharEmbedding(vocab) :param vocab: 词表 :param embed_size: embedding的大小。默认值为50. diff --git a/reproduction/matching/matching.py b/reproduction/matching/matching.py index 52c1c3b5..8251b3bc 100644 --- a/reproduction/matching/matching.py +++ b/reproduction/matching/matching.py @@ -2,31 +2,31 @@ import os import torch -from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric +from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const from fastNLP.io.dataset_loader import MatchingLoader from reproduction.matching.model.bert import BertForNLI +from reproduction.matching.model.esim import ESIMModel -# bert_dirs = 'path/to/bert/dir' -bert_dirs = '/remote-home/ygxu/BERT/BERT_English_uncased_L-12_H-768_A_12' +bert_dirs = 'path/to/bert/dir' # load data set -data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process( - {#'train': './data/snli/snli_1.0_train.jsonl', +# data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process(... +data_info = MatchingLoader(data_format='snli', for_model='esim').process( + {'train': './data/snli/snli_1.0_train.jsonl', 'dev': './data/snli/snli_1.0_dev.jsonl', - 'test': './data/snli/snli_1.0_test.jsonl'} + 'test': './data/snli/snli_1.0_test.jsonl'}, + input_field=[Const.TARGET] ) -print('successfully load data sets!') +# model = BertForNLI(bert_dir=bert_dirs) +model = ESIMModel(data_info.embeddings['elmo'],) - -model = BertForNLI(bert_dir=bert_dirs) - -trainer = Trainer(train_data=data_info.datasets['dev'], model=model, - optimizer=Adam(lr=2e-5, model_params=model.parameters()), - batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1, +trainer = Trainer(train_data=data_info.datasets['train'], model=model, + optimizer=Adam(lr=1e-4, model_params=model.parameters()), + batch_size=torch.cuda.device_count() * 24, n_epochs=20, print_every=-1, dev_data=data_info.datasets['dev'], metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], check_code_level=-1) diff --git a/reproduction/matching/model/esim.py b/reproduction/matching/model/esim.py new file mode 100644 index 00000000..0551bbdb --- /dev/null +++ b/reproduction/matching/model/esim.py @@ -0,0 +1,182 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.nn import CrossEntropyLoss + +from fastNLP.models import BaseModel +from fastNLP.modules.encoder.embedding import TokenEmbedding +from fastNLP.modules.encoder.lstm import LSTM +from fastNLP.core.const import Const +from fastNLP.core.utils import seq_len_to_mask + + +class ESIMModel(BaseModel): + def __init__(self, init_embedding: TokenEmbedding, hidden_size=None, num_labels=3, dropout_rate=0.3, + dropout_embed=0.1): + super(ESIMModel, self).__init__() + + self.embedding = init_embedding + self.dropout_embed = EmbedDropout(p=dropout_embed) + if hidden_size is None: + hidden_size = self.embedding.embed_size + self.rnn = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate) + # self.rnn = LSTM(self.embedding.embed_size, hidden_size, dropout=dropout_rate, bidirectional=True) + + self.interfere = nn.Sequential(nn.Dropout(p=dropout_rate), + nn.Linear(8 * hidden_size, hidden_size), + nn.ReLU()) + nn.init.xavier_uniform_(self.interfere[1].weight.data) + self.bi_attention = SoftmaxAttention() + + self.rnn_high = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate) + # self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True) + + self.classifier = nn.Sequential(nn.Dropout(p=dropout_rate), + nn.Linear(8 * hidden_size, hidden_size), + nn.Tanh(), + nn.Dropout(p=dropout_rate), + nn.Linear(hidden_size, num_labels)) + nn.init.xavier_uniform_(self.classifier[1].weight.data) + nn.init.xavier_uniform_(self.classifier[4].weight.data) + + def forward(self, words1, words2, seq_len1, seq_len2, target=None): + mask1 = seq_len_to_mask(seq_len1) + mask2 = seq_len_to_mask(seq_len2) + a0 = self.embedding(words1) # B * len * emb_dim + b0 = self.embedding(words2) + a0, b0 = self.dropout_embed(a0), self.dropout_embed(b0) + a = self.rnn(a0, mask1.byte()) # a: [B, PL, 2 * H] + b = self.rnn(b0, mask2.byte()) + + ai, bi = self.bi_attention(a, mask1, b, mask2) + + a_ = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 8 * H] + b_ = torch.cat((b, bi, b - bi, b * bi), dim=2) + a_f = self.interfere(a_) + b_f = self.interfere(b_) + + a_h = self.rnn_high(a_f, mask1.byte()) # ma: [B, PL, 2 * H] + b_h = self.rnn_high(b_f, mask2.byte()) + + a_avg = self.mean_pooling(a_h, mask1, dim=1) + a_max, _ = self.max_pooling(a_h, mask1, dim=1) + b_avg = self.mean_pooling(b_h, mask2, dim=1) + b_max, _ = self.max_pooling(b_h, mask2, dim=1) + + out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1) # v: [B, 8 * H] + logits = torch.tanh(self.classifier(out)) + + if target is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits, target) + + return {Const.LOSS: loss, Const.OUTPUT: logits} + else: + return {Const.OUTPUT: logits} + + def predict(self, **kwargs): + return self.forward(**kwargs) + + # input [batch_size, len , hidden] + # mask [batch_size, len] (111...00) + @staticmethod + def mean_pooling(input, mask, dim=1): + masks = mask.view(mask.size(0), mask.size(1), -1).float() + return torch.sum(input * masks, dim=dim) / torch.sum(masks, dim=1) + + @staticmethod + def max_pooling(input, mask, dim=1): + my_inf = 10e12 + masks = mask.view(mask.size(0), mask.size(1), -1) + masks = masks.expand(-1, -1, input.size(2)).float() + return torch.max(input + masks.le(0.5).float() * -my_inf, dim=dim) + + +class EmbedDropout(nn.Dropout): + + def forward(self, sequences_batch): + ones = sequences_batch.data.new_ones(sequences_batch.shape[0], sequences_batch.shape[-1]) + dropout_mask = nn.functional.dropout(ones, self.p, self.training, inplace=False) + return dropout_mask.unsqueeze(1) * sequences_batch + + +class BiRNN(nn.Module): + def __init__(self, input_size, hidden_size, dropout_rate=0.3): + super(BiRNN, self).__init__() + self.dropout_rate = dropout_rate + self.rnn = nn.LSTM(input_size, hidden_size, + num_layers=1, + bidirectional=True, + batch_first=True) + + def forward(self, x, x_mask): + # Sort x + lengths = x_mask.data.eq(1).long().sum(1).squeeze() + _, idx_sort = torch.sort(lengths, dim=0, descending=True) + _, idx_unsort = torch.sort(idx_sort, dim=0) + lengths = list(lengths[idx_sort]) + + x = x.index_select(0, idx_sort) + # Pack it up + rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True) + # Apply dropout to input + if self.dropout_rate > 0: + dropout_input = F.dropout(rnn_input.data, p=self.dropout_rate, training=self.training) + rnn_input = nn.utils.rnn.PackedSequence(dropout_input, rnn_input.batch_sizes) + output = self.rnn(rnn_input)[0] + # Unpack everything + output = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)[0] + output = output.index_select(0, idx_unsort) + if output.size(1) != x_mask.size(1): + padding = torch.zeros(output.size(0), + x_mask.size(1) - output.size(1), + output.size(2)).type(output.data.type()) + output = torch.cat([output, padding], 1) + return output + + +def masked_softmax(tensor, mask): + tensor_shape = tensor.size() + reshaped_tensor = tensor.view(-1, tensor_shape[-1]) + + # Reshape the mask so it matches the size of the input tensor. + while mask.dim() < tensor.dim(): + mask = mask.unsqueeze(1) + mask = mask.expand_as(tensor).contiguous().float() + reshaped_mask = mask.view(-1, mask.size()[-1]) + result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1) + result = result * reshaped_mask + # 1e-13 is added to avoid divisions by zero. + result = result / (result.sum(dim=-1, keepdim=True) + 1e-13) + return result.view(*tensor_shape) + + +def weighted_sum(tensor, weights, mask): + w_sum = weights.bmm(tensor) + while mask.dim() < w_sum.dim(): + mask = mask.unsqueeze(1) + mask = mask.transpose(-1, -2) + mask = mask.expand_as(w_sum).contiguous().float() + return w_sum * mask + + +class SoftmaxAttention(nn.Module): + + def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask): + similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1) + .contiguous()) + + prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask) + hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2) + .contiguous(), + premise_mask) + + attended_premises = weighted_sum(hypothesis_batch, + prem_hyp_attn, + premise_mask) + attended_hypotheses = weighted_sum(premise_batch, + hyp_prem_attn, + hypothesis_mask) + + return attended_premises, attended_hypotheses \ No newline at end of file