From d70aa96e4ce056a5627243e361935cccbc6f46c0 Mon Sep 17 00:00:00 2001
From: xuyige
Date: Tue, 9 Jul 2019 01:01:47 +0800
Subject: [PATCH] Major update: 1. update requirements; 2. move the contents
 of modules.aggregator into modules.encoder; 3. rename SQuADMetric to
 ExtractiveQAMetric; 4. update the reproduction READMEs; 5. move the
 dataloaders from reproduction/text_classification into
 fastNLP.io.data_loader and adapt them
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/__init__.py | 3 +-
 fastNLP/core/__init__.py | 2 +-
 fastNLP/core/metrics.py | 13 +-
 fastNLP/io/data_loader/__init__.py | 20 +-
 fastNLP/io/data_loader/imdb.py | 96 +++++++++
 fastNLP/io/data_loader/matching.py | 204 +-----------------
 fastNLP/io/data_loader/mnli.py | 60 ++++++
 fastNLP/io/data_loader/mtl.py | 65 ++++++
 fastNLP/io/data_loader/qnli.py | 45 ++++
 fastNLP/io/data_loader/quora.py | 32 +++
 fastNLP/io/data_loader/rte.py | 45 ++++
 fastNLP/io/data_loader/snli.py | 44 ++++
 fastNLP/io/data_loader/sst.py | 84 +++++++-
 fastNLP/io/data_loader/yelp.py | 126 +++++++++++
 fastNLP/modules/aggregator/__init__.py | 14 --
 .../{aggregator => encoder}/attention.py | 4 +-
 .../{aggregator => encoder}/pooling.py | 0
 fastNLP/modules/encoder/transformer.py | 2 +-
 .../{readme.md => README.md} | 10 +-
 reproduction/matching/README.md | 10 +-
 requirements.txt | 2 +
 21 files changed, 632 insertions(+), 249 deletions(-)
 create mode 100644 fastNLP/io/data_loader/imdb.py
 create mode 100644 fastNLP/io/data_loader/mnli.py
 create mode 100644 fastNLP/io/data_loader/mtl.py
 create mode 100644 fastNLP/io/data_loader/qnli.py
 create mode 100644 fastNLP/io/data_loader/quora.py
 create mode 100644 fastNLP/io/data_loader/rte.py
 create mode 100644 fastNLP/io/data_loader/snli.py
 create mode 100644 fastNLP/io/data_loader/yelp.py
 delete mode 100644 fastNLP/modules/aggregator/__init__.py
 rename fastNLP/modules/{aggregator => encoder}/attention.py (98%)
 rename fastNLP/modules/{aggregator => encoder}/pooling.py (100%)
 rename reproduction/coreference_resolution/{readme.md => README.md} (85%)

diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py
index e666f65f..12d421a2 100644
--- a/fastNLP/__init__.py
+++ b/fastNLP/__init__.py
@@ -37,7 +37,7 @@ __all__ = [
     "AccuracyMetric",
     "SpanFPreRecMetric",
-    "SQuADMetric",
+    "ExtractiveQAMetric",
 
     "Optimizer",
     "SGD",
@@ -61,3 +61,4 @@ __version__ = '0.4.0'
 from .core import *
 from . import models
 from . 
import modules +from .io import data_loader diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index 792bff66..efc83017 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -21,7 +21,7 @@ from .dataset import DataSet from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder from .instance import Instance from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward -from .metrics import AccuracyMetric, SpanFPreRecMetric, SQuADMetric +from .metrics import AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric from .optimizer import Optimizer, SGD, Adam from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler from .tester import Tester diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index d54bf8ec..887a7abe 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -6,7 +6,7 @@ __all__ = [ "MetricBase", "AccuracyMetric", "SpanFPreRecMetric", - "SQuADMetric" + "ExtractiveQAMetric" ] import inspect @@ -24,6 +24,7 @@ from .utils import seq_len_to_mask from .vocabulary import Vocabulary from abc import abstractmethod + class MetricBase(object): """ 所有metrics的基类,,所有的传入到Trainer, Tester的Metric需要继承自该对象,需要覆盖写入evaluate(), get_metric()方法。 @@ -735,11 +736,11 @@ def _pred_topk(y_prob, k=1): return y_pred_topk, y_prob_topk -class SQuADMetric(MetricBase): - r""" - 别名::class:`fastNLP.SQuADMetric` :class:`fastNLP.core.metrics.SQuADMetric` +class ExtractiveQAMetric(MetricBase): + """ + 别名::class:`fastNLP.ExtractiveQAMetric` :class:`fastNLP.core.metrics.ExtractiveQAMetric` - SQuAD数据集metric + 抽取式QA(如SQuAD)的metric. :param pred1: 参数映射表中 `pred1` 的映射关系,None表示映射关系为 `pred1` -> `pred1` :param pred2: 参数映射表中 `pred2` 的映射关系,None表示映射关系为 `pred2` -> `pred2` @@ -755,7 +756,7 @@ class SQuADMetric(MetricBase): def __init__(self, pred1=None, pred2=None, target1=None, target2=None, beta=1, right_open=True, print_predict_stat=False): - super(SQuADMetric, self).__init__() + super(ExtractiveQAMetric, self).__init__() self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2) diff --git a/fastNLP/io/data_loader/__init__.py b/fastNLP/io/data_loader/__init__.py index 6f4dd973..893ef0e2 100644 --- a/fastNLP/io/data_loader/__init__.py +++ b/fastNLP/io/data_loader/__init__.py @@ -4,16 +4,26 @@ 这些模块的使用方法如下: """ __all__ = [ - 'SSTLoader', - + 'IMDBLoader', 'MatchingLoader', - 'SNLILoader', 'MNLILoader', + 'MTL16Loader', 'QNLILoader', 'QuoraLoader', 'RTELoader', + 'SSTLoader', + 'SNLILoader', + 'YelpLoader', ] + +from .imdb import IMDBLoader +from .matching import MatchingLoader +from .mnli import MNLILoader +from .mtl import MTL16Loader +from .qnli import QNLILoader +from .quora import QuoraLoader +from .rte import RTELoader +from .snli import SNLILoader from .sst import SSTLoader -from .matching import MatchingLoader, SNLILoader, \ - MNLILoader, QNLILoader, QuoraLoader, RTELoader +from .yelp import YelpLoader diff --git a/fastNLP/io/data_loader/imdb.py b/fastNLP/io/data_loader/imdb.py new file mode 100644 index 00000000..b4c2c1f9 --- /dev/null +++ b/fastNLP/io/data_loader/imdb.py @@ -0,0 +1,96 @@ + +from typing import Union, Dict + +from ..embed_loader import EmbeddingOption, EmbedLoader +from ..base_loader import DataSetLoader, DataInfo +from ...core.vocabulary import VocabularyOption, Vocabulary +from ...core.dataset import DataSet +from ...core.instance import Instance +from ...core.const import Const + +from ..utils import get_tokenizer + + +class IMDBLoader(DataSetLoader): + """ + 
读取IMDB数据集,DataSet包含以下fields: + + words: list(str), 需要分类的文本 + target: str, 文本的标签 + + """ + + def __init__(self): + super(IMDBLoader, self).__init__() + self.tokenizer = get_tokenizer() + + def _load(self, path): + dataset = DataSet() + with open(path, 'r', encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + parts = line.split('\t') + target = parts[0] + words = self.tokenizer(parts[1].lower()) + dataset.append(Instance(words=words, target=target)) + + if len(dataset) == 0: + raise RuntimeError(f"{path} has no valid data.") + + return dataset + + def process(self, + paths: Union[str, Dict[str, str]], + src_vocab_opt: VocabularyOption = None, + tgt_vocab_opt: VocabularyOption = None, + char_level_op=False): + + datasets = {} + info = DataInfo() + for name, path in paths.items(): + dataset = self.load(path) + datasets[name] = dataset + + def wordtochar(words): + chars = [] + for word in words: + word = word.lower() + for char in word: + chars.append(char) + chars.append('') + chars.pop() + return chars + + if char_level_op: + for dataset in datasets.values(): + dataset.apply_field(wordtochar, field_name="words", new_field_name='chars') + + datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False) + + src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) + src_vocab.from_dataset(datasets['train'], field_name='words') + + src_vocab.index_dataset(*datasets.values(), field_name='words') + + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) + tgt_vocab.from_dataset(datasets['train'], field_name='target') + tgt_vocab.index_dataset(*datasets.values(), field_name='target') + + info.vocabs = { + Const.INPUT: src_vocab, + Const.TARGET: tgt_vocab + } + + info.datasets = datasets + + for name, dataset in info.datasets.items(): + dataset.set_input(Const.INPUT) + dataset.set_target(Const.TARGET) + + return info + + + diff --git a/fastNLP/io/data_loader/matching.py b/fastNLP/io/data_loader/matching.py index 3d131bcb..771f2748 100644 --- a/fastNLP/io/data_loader/matching.py +++ b/fastNLP/io/data_loader/matching.py @@ -5,14 +5,13 @@ from typing import Union, Dict from ...core.const import Const from ...core.vocabulary import Vocabulary from ..base_loader import DataInfo, DataSetLoader -from ..dataset_loader import JsonLoader, CSVLoader from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR from ...modules.encoder._bert import BertTokenizer class MatchingLoader(DataSetLoader): """ - 别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` + 别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.data_loader.MatchingLoader` 读取Matching任务的数据集 @@ -227,204 +226,3 @@ class MatchingLoader(DataSetLoader): data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()]) return data_info - - -class SNLILoader(MatchingLoader, JsonLoader): - """ - 别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` - - 读取SNLI数据集,读取的DataSet包含fields:: - - words1: list(str),第一句文本, premise - words2: list(str), 第二句文本, hypothesis - target: str, 真实标签 - - 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip - """ - - def __init__(self, paths: dict=None): - fields = { - 'sentence1_binary_parse': Const.INPUTS(0), - 'sentence2_binary_parse': Const.INPUTS(1), - 'gold_label': Const.TARGET, - } - paths = paths if paths is not None else { - 'train': 'snli_1.0_train.jsonl', - 'dev': 
'snli_1.0_dev.jsonl', - 'test': 'snli_1.0_test.jsonl'} - MatchingLoader.__init__(self, paths=paths) - JsonLoader.__init__(self, fields=fields) - - def _load(self, path): - ds = JsonLoader._load(self, path) - - parentheses_table = str.maketrans({'(': None, ')': None}) - - ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), - new_field_name=Const.INPUTS(0)) - ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), - new_field_name=Const.INPUTS(1)) - ds.drop(lambda x: x[Const.TARGET] == '-') - return ds - - -class RTELoader(MatchingLoader, CSVLoader): - """ - 别名::class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader` - - 读取RTE数据集,读取的DataSet包含fields:: - - words1: list(str),第一句文本, premise - words2: list(str), 第二句文本, hypothesis - target: str, 真实标签 - - 数据来源: - """ - - def __init__(self, paths: dict=None): - paths = paths if paths is not None else { - 'train': 'train.tsv', - 'dev': 'dev.tsv', - 'test': 'test.tsv' # test set has not label - } - MatchingLoader.__init__(self, paths=paths) - self.fields = { - 'sentence1': Const.INPUTS(0), - 'sentence2': Const.INPUTS(1), - 'label': Const.TARGET, - } - CSVLoader.__init__(self, sep='\t') - - def _load(self, path): - ds = CSVLoader._load(self, path) - - for k, v in self.fields.items(): - if v in ds.get_field_names(): - ds.rename_field(k, v) - for fields in ds.get_all_fields(): - if Const.INPUT in fields: - ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) - - return ds - - -class QNLILoader(MatchingLoader, CSVLoader): - """ - 别名::class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader` - - 读取QNLI数据集,读取的DataSet包含fields:: - - words1: list(str),第一句文本, premise - words2: list(str), 第二句文本, hypothesis - target: str, 真实标签 - - 数据来源: - """ - - def __init__(self, paths: dict=None): - paths = paths if paths is not None else { - 'train': 'train.tsv', - 'dev': 'dev.tsv', - 'test': 'test.tsv' # test set has not label - } - MatchingLoader.__init__(self, paths=paths) - self.fields = { - 'question': Const.INPUTS(0), - 'sentence': Const.INPUTS(1), - 'label': Const.TARGET, - } - CSVLoader.__init__(self, sep='\t') - - def _load(self, path): - ds = CSVLoader._load(self, path) - - for k, v in self.fields.items(): - if v in ds.get_field_names(): - ds.rename_field(k, v) - for fields in ds.get_all_fields(): - if Const.INPUT in fields: - ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) - - return ds - - -class MNLILoader(MatchingLoader, CSVLoader): - """ - 别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader` - - 读取MNLI数据集,读取的DataSet包含fields:: - - words1: list(str),第一句文本, premise - words2: list(str), 第二句文本, hypothesis - target: str, 真实标签 - - 数据来源: - """ - - def __init__(self, paths: dict=None): - paths = paths if paths is not None else { - 'train': 'train.tsv', - 'dev_matched': 'dev_matched.tsv', - 'dev_mismatched': 'dev_mismatched.tsv', - 'test_matched': 'test_matched.tsv', - 'test_mismatched': 'test_mismatched.tsv', - # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', - # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', - - # test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle) - } - MatchingLoader.__init__(self, paths=paths) - CSVLoader.__init__(self, sep='\t') - self.fields = { - 'sentence1_binary_parse': Const.INPUTS(0), - 'sentence2_binary_parse': Const.INPUTS(1), - 'gold_label': Const.TARGET, - } - - def _load(self, path): - ds = CSVLoader._load(self, path) - - for k, v in 
self.fields.items(): - if k in ds.get_field_names(): - ds.rename_field(k, v) - - if Const.TARGET in ds.get_field_names(): - if ds[0][Const.TARGET] == 'hidden': - ds.delete_field(Const.TARGET) - - parentheses_table = str.maketrans({'(': None, ')': None}) - - ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), - new_field_name=Const.INPUTS(0)) - ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), - new_field_name=Const.INPUTS(1)) - if Const.TARGET in ds.get_field_names(): - ds.drop(lambda x: x[Const.TARGET] == '-') - return ds - - -class QuoraLoader(MatchingLoader, CSVLoader): - """ - 别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader` - - 读取MNLI数据集,读取的DataSet包含fields:: - - words1: list(str),第一句文本, premise - words2: list(str), 第二句文本, hypothesis - target: str, 真实标签 - - 数据来源: - """ - - def __init__(self, paths: dict=None): - paths = paths if paths is not None else { - 'train': 'train.tsv', - 'dev': 'dev.tsv', - 'test': 'test.tsv', - } - MatchingLoader.__init__(self, paths=paths) - CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')) - - def _load(self, path): - ds = CSVLoader._load(self, path) - return ds diff --git a/fastNLP/io/data_loader/mnli.py b/fastNLP/io/data_loader/mnli.py new file mode 100644 index 00000000..48923736 --- /dev/null +++ b/fastNLP/io/data_loader/mnli.py @@ -0,0 +1,60 @@ + +from ...core import Const + +from .matching import MatchingLoader +from ..dataset_loader import CSVLoader + + +class MNLILoader(MatchingLoader, CSVLoader): + """ + 别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.data_loader.MNLILoader` + + 读取MNLI数据集,读取的DataSet包含fields:: + + words1: list(str),第一句文本, premise + words2: list(str), 第二句文本, hypothesis + target: str, 真实标签 + + 数据来源: + """ + + def __init__(self, paths: dict=None): + paths = paths if paths is not None else { + 'train': 'train.tsv', + 'dev_matched': 'dev_matched.tsv', + 'dev_mismatched': 'dev_mismatched.tsv', + 'test_matched': 'test_matched.tsv', + 'test_mismatched': 'test_mismatched.tsv', + # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', + # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', + + # test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle) + } + MatchingLoader.__init__(self, paths=paths) + CSVLoader.__init__(self, sep='\t') + self.fields = { + 'sentence1_binary_parse': Const.INPUTS(0), + 'sentence2_binary_parse': Const.INPUTS(1), + 'gold_label': Const.TARGET, + } + + def _load(self, path): + ds = CSVLoader._load(self, path) + + for k, v in self.fields.items(): + if k in ds.get_field_names(): + ds.rename_field(k, v) + + if Const.TARGET in ds.get_field_names(): + if ds[0][Const.TARGET] == 'hidden': + ds.delete_field(Const.TARGET) + + parentheses_table = str.maketrans({'(': None, ')': None}) + + ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), + new_field_name=Const.INPUTS(0)) + ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), + new_field_name=Const.INPUTS(1)) + if Const.TARGET in ds.get_field_names(): + ds.drop(lambda x: x[Const.TARGET] == '-') + return ds diff --git a/fastNLP/io/data_loader/mtl.py b/fastNLP/io/data_loader/mtl.py new file mode 100644 index 00000000..548a985b --- /dev/null +++ b/fastNLP/io/data_loader/mtl.py @@ -0,0 +1,65 @@ + +from typing import Union, Dict + +from ..base_loader import DataInfo +from ..dataset_loader import CSVLoader +from ...core.vocabulary 
import Vocabulary, VocabularyOption
+from ...core.const import Const
+from ..utils import check_dataloader_paths
+
+
+class MTL16Loader(CSVLoader):
+    """
+    读取MTL16数据集,DataSet包含以下fields:
+
+        words: list(str), 需要分类的文本
+        target: str, 文本的标签
+
+    数据来源:https://pan.baidu.com/s/1c2L6vdA
+
+    """
+
+    def __init__(self):
+        super(MTL16Loader, self).__init__(headers=(Const.TARGET, Const.INPUT), sep='\t')
+
+    def _load(self, path):
+        dataset = super(MTL16Loader, self)._load(path)
+        dataset.apply(lambda x: x[Const.INPUT].lower().split(), new_field_name=Const.INPUT)
+        if len(dataset) == 0:
+            raise RuntimeError(f"{path} has no valid data.")
+
+        return dataset
+
+    def process(self,
+                paths: Union[str, Dict[str, str]],
+                src_vocab_opt: VocabularyOption = None,
+                tgt_vocab_opt: VocabularyOption = None,):
+
+        paths = check_dataloader_paths(paths)
+        datasets = {}
+        info = DataInfo()
+        for name, path in paths.items():
+            dataset = self.load(path)
+            datasets[name] = dataset
+
+        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
+        src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT)
+        src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT)
+
+        tgt_vocab = Vocabulary(unknown=None, padding=None) \
+            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
+        tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET)
+        tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET)
+
+        info.vocabs = {
+            Const.INPUT: src_vocab,
+            Const.TARGET: tgt_vocab
+        }
+
+        info.datasets = datasets
+
+        for name, dataset in info.datasets.items():
+            dataset.set_input(Const.INPUT)
+            dataset.set_target(Const.TARGET)
+
+        return info
diff --git a/fastNLP/io/data_loader/qnli.py b/fastNLP/io/data_loader/qnli.py
new file mode 100644
index 00000000..650c6be7
--- /dev/null
+++ b/fastNLP/io/data_loader/qnli.py
@@ -0,0 +1,45 @@
+
+from ...core import Const
+
+from .matching import MatchingLoader
+from ..dataset_loader import CSVLoader
+
+
+class QNLILoader(MatchingLoader, CSVLoader):
+    """
+    别名::class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.data_loader.QNLILoader`
+
+    读取QNLI数据集,读取的DataSet包含fields::
+
+        words1: list(str),第一句文本, premise
+        words2: list(str), 第二句文本, hypothesis
+        target: str, 真实标签
+
+    数据来源:
+    """
+
+    def __init__(self, paths: dict=None):
+        paths = paths if paths is not None else {
+            'train': 'train.tsv',
+            'dev': 'dev.tsv',
+            'test': 'test.tsv'  # test set has no labels
+        }
+        MatchingLoader.__init__(self, paths=paths)
+        self.fields = {
+            'question': Const.INPUTS(0),
+            'sentence': Const.INPUTS(1),
+            'label': Const.TARGET,
+        }
+        CSVLoader.__init__(self, sep='\t')
+
+    def _load(self, path):
+        ds = CSVLoader._load(self, path)
+
+        for k, v in self.fields.items():
+            if k in ds.get_field_names():
+                ds.rename_field(k, v)
+        for fields in ds.get_all_fields():
+            if Const.INPUT in fields:
+                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
+
+        return ds
diff --git a/fastNLP/io/data_loader/quora.py b/fastNLP/io/data_loader/quora.py
new file mode 100644
index 00000000..2c466a24
--- /dev/null
+++ b/fastNLP/io/data_loader/quora.py
@@ -0,0 +1,32 @@
+
+from ...core import Const
+
+from .matching import MatchingLoader
+from ..dataset_loader import CSVLoader
+
+
+class QuoraLoader(MatchingLoader, CSVLoader):
+    """
+    别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.data_loader.QuoraLoader`
+
+    读取Quora数据集,读取的DataSet包含fields::
+
+        words1: list(str),第一句文本, premise
+        words2: list(str), 第二句文本, hypothesis
+        target: str, 真实标签
+
+    数据来源:
+    """
+
+    def __init__(self, 
paths: dict=None): + paths = paths if paths is not None else { + 'train': 'train.tsv', + 'dev': 'dev.tsv', + 'test': 'test.tsv', + } + MatchingLoader.__init__(self, paths=paths) + CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')) + + def _load(self, path): + ds = CSVLoader._load(self, path) + return ds diff --git a/fastNLP/io/data_loader/rte.py b/fastNLP/io/data_loader/rte.py new file mode 100644 index 00000000..9bf05d60 --- /dev/null +++ b/fastNLP/io/data_loader/rte.py @@ -0,0 +1,45 @@ + +from ...core import Const + +from .matching import MatchingLoader +from ..dataset_loader import CSVLoader + + +class RTELoader(MatchingLoader, CSVLoader): + """ + 别名::class:`fastNLP.io.RTELoader` :class:`fastNLP.io.data_loader.RTELoader` + + 读取RTE数据集,读取的DataSet包含fields:: + + words1: list(str),第一句文本, premise + words2: list(str), 第二句文本, hypothesis + target: str, 真实标签 + + 数据来源: + """ + + def __init__(self, paths: dict=None): + paths = paths if paths is not None else { + 'train': 'train.tsv', + 'dev': 'dev.tsv', + 'test': 'test.tsv' # test set has not label + } + MatchingLoader.__init__(self, paths=paths) + self.fields = { + 'sentence1': Const.INPUTS(0), + 'sentence2': Const.INPUTS(1), + 'label': Const.TARGET, + } + CSVLoader.__init__(self, sep='\t') + + def _load(self, path): + ds = CSVLoader._load(self, path) + + for k, v in self.fields.items(): + if k in ds.get_field_names(): + ds.rename_field(k, v) + for fields in ds.get_all_fields(): + if Const.INPUT in fields: + ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) + + return ds diff --git a/fastNLP/io/data_loader/snli.py b/fastNLP/io/data_loader/snli.py new file mode 100644 index 00000000..7c91ca86 --- /dev/null +++ b/fastNLP/io/data_loader/snli.py @@ -0,0 +1,44 @@ + +from ...core import Const + +from .matching import MatchingLoader +from ..dataset_loader import JsonLoader + + +class SNLILoader(MatchingLoader, JsonLoader): + """ + 别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.data_loader.SNLILoader` + + 读取SNLI数据集,读取的DataSet包含fields:: + + words1: list(str),第一句文本, premise + words2: list(str), 第二句文本, hypothesis + target: str, 真实标签 + + 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip + """ + + def __init__(self, paths: dict=None): + fields = { + 'sentence1_binary_parse': Const.INPUTS(0), + 'sentence2_binary_parse': Const.INPUTS(1), + 'gold_label': Const.TARGET, + } + paths = paths if paths is not None else { + 'train': 'snli_1.0_train.jsonl', + 'dev': 'snli_1.0_dev.jsonl', + 'test': 'snli_1.0_test.jsonl'} + MatchingLoader.__init__(self, paths=paths) + JsonLoader.__init__(self, fields=fields) + + def _load(self, path): + ds = JsonLoader._load(self, path) + + parentheses_table = str.maketrans({'(': None, ')': None}) + + ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), + new_field_name=Const.INPUTS(0)) + ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), + new_field_name=Const.INPUTS(1)) + ds.drop(lambda x: x[Const.TARGET] == '-') + return ds diff --git a/fastNLP/io/data_loader/sst.py b/fastNLP/io/data_loader/sst.py index 8d0d005f..a7a35aee 100644 --- a/fastNLP/io/data_loader/sst.py +++ b/fastNLP/io/data_loader/sst.py @@ -1,19 +1,19 @@ -from typing import Iterable + +from typing import Union, Dict from nltk import Tree -import spacy + from ..base_loader import DataInfo, DataSetLoader +from ..dataset_loader import CSVLoader from ...core.vocabulary import VocabularyOption, Vocabulary from 
...core.dataset import DataSet +from ...core.const import Const from ...core.instance import Instance from ..utils import check_dataloader_paths, get_tokenizer class SSTLoader(DataSetLoader): - URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' - DATA_DIR = 'sst/' - """ - 别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader` + 别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.data_loader.SSTLoader` 读取SST数据集, DataSet包含fields:: @@ -26,6 +26,9 @@ class SSTLoader(DataSetLoader): :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` """ + URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' + DATA_DIR = 'sst/' + def __init__(self, subtree=False, fine_grained=False): self.subtree = subtree @@ -98,3 +101,72 @@ class SSTLoader(DataSetLoader): return info + +class SST2Loader(CSVLoader): + """ + 数据来源"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', + """ + + def __init__(self): + super(SST2Loader, self).__init__(sep='\t') + self.tokenizer = get_tokenizer() + self.field = {'sentence': Const.INPUT, 'label': Const.TARGET} + + def _load(self, path: str) -> DataSet: + ds = super(SST2Loader, self)._load(path) + ds.apply(lambda x: self.tokenizer(x[Const.INPUT]), new_field_name=Const.INPUT) + print("all count:", len(ds)) + return ds + + def process(self, + paths: Union[str, Dict[str, str]], + src_vocab_opt: VocabularyOption = None, + tgt_vocab_opt: VocabularyOption = None, + char_level_op=False): + + paths = check_dataloader_paths(paths) + datasets = {} + info = DataInfo() + for name, path in paths.items(): + dataset = self.load(path) + datasets[name] = dataset + + def wordtochar(words): + chars = [] + for word in words: + word = word.lower() + for char in word: + chars.append(char) + chars.append('') + chars.pop() + return chars + + input_name, target_name = Const.INPUT, Const.TARGET + info.vocabs={} + + # 就分隔为char形式 + if char_level_op: + for dataset in datasets.values(): + dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT) + src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) + src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT) + src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT) + + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) + tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET) + tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET) + + info.vocabs = { + Const.INPUT: src_vocab, + Const.TARGET: tgt_vocab + } + + info.datasets = datasets + + for name, dataset in info.datasets.items(): + dataset.set_input(Const.INPUT) + dataset.set_target(Const.TARGET) + + return info + diff --git a/fastNLP/io/data_loader/yelp.py b/fastNLP/io/data_loader/yelp.py new file mode 100644 index 00000000..1ac4421d --- /dev/null +++ b/fastNLP/io/data_loader/yelp.py @@ -0,0 +1,126 @@ + +import csv +from typing import Iterable + +from ...core.const import Const +from ...core import DataSet, Instance, Vocabulary +from ...core.vocabulary import VocabularyOption +from ..base_loader import DataInfo,DataSetLoader +from typing import Union, Dict +from ..utils import check_dataloader_paths, get_tokenizer + + +class YelpLoader(DataSetLoader): + """ + 读取Yelp_full/Yelp_polarity数据集, DataSet包含fields: + words: list(str), 需要分类的文本 + target: str, 文本的标签 + 
chars:list(str),未index的字符列表
+
+    数据集:yelp_full/yelp_polarity
+    :param fine_grained: 是否保留5分类的细粒度标签,若 ``False`` ,将very negative/very positive分别并入negative/positive。Default: ``False``
+    :param lower: 是否需要自动转小写,默认为False。
+    """
+
+    def __init__(self, fine_grained=False, lower=False):
+        super(YelpLoader, self).__init__()
+        tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral',
+                 '4.0': 'positive', '5.0': 'very positive'}
+        if not fine_grained:
+            tag_v['1.0'] = tag_v['2.0']
+            tag_v['5.0'] = tag_v['4.0']
+        self.fine_grained = fine_grained
+        self.tag_v = tag_v
+        self.lower = lower
+        self.tokenizer = get_tokenizer()
+
+    def _load(self, path):
+        ds = DataSet()
+        csv_reader = csv.reader(open(path, encoding='utf-8'))
+        all_count = 0
+        real_count = 0
+        for row in csv_reader:
+            all_count += 1
+            if len(row) == 2:
+                target = self.tag_v[row[0] + ".0"]
+                words = clean_str(row[1], self.tokenizer, self.lower)
+                if len(words) != 0:
+                    ds.append(Instance(words=words, target=target))
+                    real_count += 1
+        print("all count:", all_count)
+        print("real count:", real_count)
+        return ds
+
+    def process(self, paths: Union[str, Dict[str, str]],
+                train_ds: Iterable[str] = None,
+                src_vocab_op: VocabularyOption = None,
+                tgt_vocab_op: VocabularyOption = None,
+                char_level_op=False):
+        paths = check_dataloader_paths(paths)
+        info = DataInfo(datasets=self.load(paths))
+        src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
+        tgt_vocab = Vocabulary(unknown=None, padding=None) \
+            if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
+        _train_ds = [info.datasets[name]
+                     for name in train_ds] if train_ds else info.datasets.values()
+
+        def wordtochar(words):
+            chars = []
+            for word in words:
+                word = word.lower()
+                for char in word:
+                    chars.append(char)
+                chars.append('')
+            chars.pop()
+            return chars
+
+        input_name, target_name = Const.INPUT, Const.TARGET
+        info.vocabs = {}
+        # 就分隔为char形式
+        if char_level_op:
+            for dataset in info.datasets.values():
+                dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT)
+        else:
+            src_vocab.from_dataset(*_train_ds, field_name=input_name)
+            src_vocab.index_dataset(*info.datasets.values(), field_name=input_name, new_field_name=input_name)
+            info.vocabs[input_name] = src_vocab
+
+        tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
+        tgt_vocab.index_dataset(
+            *info.datasets.values(),
+            field_name=target_name, new_field_name=target_name)
+
+        info.vocabs[target_name] = tgt_vocab
+
+        info.datasets['train'], info.datasets['dev'] = info.datasets['train'].split(0.1, shuffle=False)
+
+        for name, dataset in info.datasets.items():
+            dataset.set_input(Const.INPUT)
+            dataset.set_target(Const.TARGET)
+
+        return info
+
+
+def clean_str(sentence, tokenizer, char_lower=False):
+    """
+    heavily borrowed from github
+    https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb
+    :param sentence: is a str
+    :return:
+    """
+    if char_lower:
+        sentence = sentence.lower()
+    import re
+    nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
+    words = tokenizer(sentence)
+    words_collection = []
+    for word in words:
+        if word in ['-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-']:
+            continue
+        tt = nonalpnum.split(word)
+        t = ''.join(tt)
+        if t != '':
+            words_collection.append(t)
+
+    return words_collection
+
diff --git a/fastNLP/modules/aggregator/__init__.py b/fastNLP/modules/aggregator/__init__.py
deleted file mode 100644
index a82138e7..00000000
--- a/fastNLP/modules/aggregator/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-__all__ = [
-    "MaxPool",
-    "MaxPoolWithMask",
- "AvgPool", - - "MultiHeadAttention", -] - -from .pooling import MaxPool -from .pooling import MaxPoolWithMask -from .pooling import AvgPool -from .pooling import AvgPoolWithMask - -from .attention import MultiHeadAttention diff --git a/fastNLP/modules/aggregator/attention.py b/fastNLP/modules/encoder/attention.py similarity index 98% rename from fastNLP/modules/aggregator/attention.py rename to fastNLP/modules/encoder/attention.py index 2bee7f2e..c0ba598d 100644 --- a/fastNLP/modules/aggregator/attention.py +++ b/fastNLP/modules/encoder/attention.py @@ -8,9 +8,9 @@ import torch import torch.nn.functional as F from torch import nn -from ..dropout import TimestepDropout +from fastNLP.modules.dropout import TimestepDropout -from ..utils import initial_parameter +from fastNLP.modules.utils import initial_parameter class DotAttention(nn.Module): diff --git a/fastNLP/modules/aggregator/pooling.py b/fastNLP/modules/encoder/pooling.py similarity index 100% rename from fastNLP/modules/aggregator/pooling.py rename to fastNLP/modules/encoder/pooling.py diff --git a/fastNLP/modules/encoder/transformer.py b/fastNLP/modules/encoder/transformer.py index 698ff95c..d6bf2f1e 100644 --- a/fastNLP/modules/encoder/transformer.py +++ b/fastNLP/modules/encoder/transformer.py @@ -3,7 +3,7 @@ __all__ = [ ] from torch import nn -from ..aggregator.attention import MultiHeadAttention +from fastNLP.modules.encoder.attention import MultiHeadAttention from ..dropout import TimestepDropout diff --git a/reproduction/coreference_resolution/readme.md b/reproduction/coreference_resolution/README.md similarity index 85% rename from reproduction/coreference_resolution/readme.md rename to reproduction/coreference_resolution/README.md index 67d8cdc7..7cbcd052 100644 --- a/reproduction/coreference_resolution/readme.md +++ b/reproduction/coreference_resolution/README.md @@ -11,7 +11,7 @@ Coreference resolution是查找文本中指向同一现实实体的所有表达 由于版权问题,本文无法提供数据集的下载,请自行下载。 原始数据集的格式为conll格式,详细介绍参考数据集给出的官方介绍页面。 -代码实现采用了论文作者Lee的预处理方法,具体细节参加[链接](https://github.com/kentonl/e2e-coref/blob/e2e/setup_training.sh)。 +代码实现采用了论文作者Lee的预处理方法,具体细节参见[链接](https://github.com/kentonl/e2e-coref/blob/e2e/setup_training.sh)。 处理之后的数据集为json格式,例子: ``` { @@ -25,12 +25,12 @@ Coreference resolution是查找文本中指向同一现实实体的所有表达 ### embedding 数据集下载 [turian emdedding](https://lil.cs.washington.edu/coref/turian.50d.txt) -[glove embedding]( https://nlp.stanford.edu/data/glove.840B.300d.zip) +[glove embedding](https://nlp.stanford.edu/data/glove.840B.300d.zip) ## 运行 -```python +```shell # 训练代码 CUDA_VISIBLE_DEVICES=0 python train.py # 测试代码 @@ -39,9 +39,9 @@ CUDA_VISIBLE_DEVICES=0 python valid.py ## 结果 原论文作者在测试集上取得了67.2%的结果,AllenNLP复现的结果为 [63.0%](https://allennlp.org/models)。 -其中allenNLP训练时没有加入speaker信息,没有variational dropout以及只使用了100的antecedents而不是250。 +其中AllenNLP训练时没有加入speaker信息,没有variational dropout以及只使用了100的antecedents而不是250。 -在与allenNLP使用同样的超参和配置时,本代码复现取得了63.6%的F1值。 +在与AllenNLP使用同样的超参和配置时,本代码复现取得了63.6%的F1值。 ## 问题 diff --git a/reproduction/matching/README.md b/reproduction/matching/README.md index 056b0212..7b4997ea 100644 --- a/reproduction/matching/README.md +++ b/reproduction/matching/README.md @@ -2,7 +2,7 @@ 这里使用fastNLP复现了几个著名的Matching任务的模型,旨在达到与论文中相符的性能。这几个任务的评价指标均为准确率(%). 复现的模型有(按论文发表时间顺序排序): -- CNTN:模型代码(still in progress)[](); 训练代码(still in progress)[](). +- CNTN:[模型代码](model/cntn.py); [训练代码](matching_cntn.py). 论文链接:[Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844). 
 - ESIM:[模型代码](model/esim.py); [训练代码](matching_esim.py).
 论文链接:[Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf).
@@ -21,7 +21,7 @@ model name | SNLI | MNLI | RTE | QNLI | Quora
 :---: | :---: | :---: | :---: | :---: | :---:
-CNTN [](); [论文](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 74.53 vs - | 60.84/-(dev) vs - | 57.4(dev) vs - | 62.53(dev) vs - | - |
+CNTN [代码](model/cntn.py); [论文](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 77.79 vs - | 63.29/63.16(dev) vs - | 57.04(dev) vs - | 62.38(dev) vs - | - |
 ESIM [代码](model/esim.py); [论文](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs 72.4/72.1* | 59.21(dev) vs - | 76.97(dev) vs - | - |
 DIIN [](); [论文](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 |
 MwAN [](); [论文](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.9 vs 88.3 | 77.3/76.7(dev) vs 78.5/77.7 | - | 74.6(dev) vs - | 85.6 vs 89.12 |
@@ -44,7 +44,7 @@ Performance on Test set:
 model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large
 :---: | :---: | :---: | :---: | :---: | :---: | :---:
-__performance__ | - | 88.13 | - | 87.9 | 90.6 | 91.16
+__performance__ | 77.79 | 88.13 | - | 87.9 | 90.6 | 91.16
 
 ## MNLI
 [Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/)
@@ -60,7 +60,7 @@ Performance on Test set(matched/mismatched):
 model name | CNTN | ESIM | DIIN | MwAN | BERT-Base
 :---: | :---: | :---: | :---: | :---: | :---: |
-__performance__ | - | 77.78/76.49 | - | 77.3/76.7(dev) | - |
+__performance__ | 63.29/63.16(dev) | 77.78/76.49 | - | 77.3/76.7(dev) | - |
 
 ## RTE
@@ -92,7 +92,7 @@ Performance on __Dev__ set:
 model name | CNTN | ESIM | DIIN | MwAN | BERT
 :---: | :---: | :---: | :---: | :---: | :---:
-__performance__ | - | 76.97 | - | 74.6 | -
+__performance__ | 62.38 | 76.97 | - | 74.6 | -
 
 ## Quora
diff --git a/requirements.txt b/requirements.txt
index f8f7a951..90b67f2c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,3 +3,5 @@ torch>=1.0.0
 tqdm>=4.28.1
 nltk>=3.4.1
 requests
+spacy
+h5py
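
Note for reviewers (not part of the diff above): a quick smoke test of the rename and the module moves introduced by this patch. This is a minimal sketch against this branch, not a pinned-down API reference — the random tensors and the `train.txt`/`test.txt` paths are placeholders, and the exact keys returned by `get_metric()` may differ from the comment.

```python
import torch

from fastNLP import ExtractiveQAMetric  # formerly fastNLP.SQuADMetric
# pooling and attention now live under modules.encoder instead of modules.aggregator
from fastNLP.modules.encoder.attention import MultiHeadAttention
from fastNLP.modules.encoder.pooling import MaxPool, MaxPoolWithMask, AvgPool
# loaders adapted from reproduction/text_classification are exported here
from fastNLP.io.data_loader import IMDBLoader, YelpLoader, SSTLoader

# ExtractiveQAMetric keeps SQuADMetric's field names: pred1/pred2 are
# answer-start/answer-end scores, target1/target2 the gold positions.
batch, seq_len = 4, 10
metric = ExtractiveQAMetric()
metric.evaluate(pred1=torch.randn(batch, seq_len),            # start scores
                pred2=torch.randn(batch, seq_len),            # end scores
                target1=torch.randint(0, seq_len, (batch,)),  # gold starts
                target2=torch.randint(0, seq_len, (batch,)))  # gold ends
print(metric.get_metric())  # EM / F-beta style numbers

# IMDBLoader.process() expects a paths dict with a 'train' entry and splits a
# dev set off the training data; the files are tab-separated "label<TAB>text".
info = IMDBLoader().process(paths={'train': 'train.txt', 'test': 'test.txt'})
print(info.datasets.keys(), info.vocabs.keys())
```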