
1. add summary loader; 2. reorganize code in ExtCNNDMPipe; 3. reorganize test data and code for ExtCNNDMPipe;

tags/v0.4.10
Yige Xu 5 years ago
commit a4babc04e2
7 changed files with 188 additions and 97 deletions
  1. fastNLP/io/loader/summarization.py  +63 -0
  2. fastNLP/io/pipe/summarization.py  +48 -38
  3. test/data_for_tests/io/cnndm/dev.label.jsonl  +4 -0
  4. test/data_for_tests/io/cnndm/test.label.jsonl  +4 -0
  5. test/data_for_tests/io/cnndm/train.cnndm.jsonl  +0 -0
  6. test/data_for_tests/io/cnndm/vocab  +0 -0
  7. test/io/pipe/test_summary.py  +69 -59

fastNLP/io/loader/summarization.py  +63 -0

@@ -0,0 +1,63 @@
"""undocumented"""

__all__ = [
    "ExtCNNDMLoader"
]

import os
from typing import Union, Dict

from ..data_bundle import DataBundle
from ..utils import check_loader_paths
from .json import JsonLoader


class ExtCNNDMLoader(JsonLoader):
    """
    After loading, the fields of the DataSet are as follows:

    .. csv-table::
        :header: "text", "summary", "label", "publication"

        ["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm"
        ["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm"
        ["..."], ["..."], [], "cnndm"

    """

    def __init__(self, fields=None):
        fields = fields or {"text": None, "summary": None, "label": None, "publication": None}
        super(ExtCNNDMLoader, self).__init__(fields=fields)

    def load(self, paths: Union[str, Dict[str, str]] = None):
        """
        Read data from the file(s) at one or more given paths and return a :class:`~fastNLP.io.DataBundle`.

        Which fields are read is determined by the ``fields`` passed when the ExtCNNDMLoader was initialized.

        :param str paths: a directory in which to look for the three files train.label.jsonl, dev.label.jsonl
            and test.label.jsonl (the directory should also contain a file named vocab, which is needed by
            :class:`~fastNLP.io.ExtCNNDMPipe`).

        :return: a :class:`~fastNLP.io.DataBundle`
        """
        if paths is None:
            paths = self.download()
        paths = check_loader_paths(paths)
        if ('train' in paths) and ('test' not in paths):
            paths['test'] = paths['train']
            paths.pop('train')

        datasets = {name: self._load(path) for name, path in paths.items()}
        data_bundle = DataBundle(datasets=datasets)
        return data_bundle

    def download(self):
        """
        If you use this dataset, please cite

        https://arxiv.org/pdf/1506.03340.pdf

        :return:
        """
        output_dir = self._get_dataset_path('ext-cnndm')
        return output_dir
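
For orientation, a minimal usage sketch of the new loader follows; it is not part of the commit, and the directory path is only illustrative (it assumes a folder laid out as the docstring above describes, i.e. train/dev/test .label.jsonl files plus a vocab file):

# Sketch only; 'path/to/cnndm' is a hypothetical directory.
from fastNLP.io.loader.summarization import ExtCNNDMLoader

loader = ExtCNNDMLoader()
data_bundle = loader.load('path/to/cnndm')
for name, dataset in data_bundle.datasets.items():
    # each instance carries the text, summary, label and publication fields
    print(name, len(dataset))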

fastNLP/io/pipe/summarization.py  +48 -38

@@ -1,15 +1,14 @@
"""undocumented"""
import os
import numpy as np
from .pipe import Pipe
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
from ..loader.json import JsonLoader
from .utils import _drop_empty_instance
from ..loader.summarization import ExtCNNDMLoader
from ..data_bundle import DataBundle
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
from ...core.const import Const
from ...core.dataset import DataSet
from ...core.instance import Instance
from ...core.vocabulary import Vocabulary
from ...core._logger import logger
WORD_PAD = "[PAD]"
@@ -18,7 +17,6 @@ DOMAIN_UNK = "X"
TAG_UNK = "X"


class ExtCNNDMPipe(Pipe):
    """
    Preprocesses the CNN/Daily Mail data for the extractive summarization task. After preprocessing, the data has the following structure:
@@ -27,13 +25,13 @@ class ExtCNNDMPipe(Pipe):
:header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target"
"""
def __init__(self, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False):
def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False):
"""
:param vocab_size: int, 词表大小
:param vocab_path: str, 外部词表路径
:param sent_max_len: int, 句子最大长度,不足的句子将padding,超出的将截断
:param doc_max_timesteps: int, 文章最多句子个数,不足的将padding,超出的将截断
:param vocab_path: str, 外部词表路径
:param domain: bool, 是否需要建立domain词表
"""
self.vocab_size = vocab_size
@@ -42,8 +40,7 @@ class ExtCNNDMPipe(Pipe):
        self.doc_max_timesteps = doc_max_timesteps
        self.domain = domain

-    def process(self, db: DataBundle):
+    def process(self, data_bundle: DataBundle):
        """
        The DataSet passed in should have the following structure:
@@ -64,24 +61,28 @@ class ExtCNNDMPipe(Pipe):
[[""],...,[""]], [[],...,[]], [], []
"""
db.apply(lambda x: _lower_text(x['text']), new_field_name='text')
db.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
db.apply(lambda x: _split_list(x['text']), new_field_name='text_wd')
db.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET)
if self.vocab_path is None:
error_msg = 'vocab file is not defined!'
logger.error(error_msg)
raise RuntimeError(error_msg)
data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text')
data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd')
data_bundle.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET)
db.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT)
data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT)
# db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask")
# pad document
db.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT)
db.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN)
db.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET)
data_bundle.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT)
data_bundle.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN)
data_bundle.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET)
db = _drop_empty_instance(db, "label")
data_bundle = _drop_empty_instance(data_bundle, "label")
# set input and target
db.set_input(Const.INPUT, Const.INPUT_LEN)
db.set_target(Const.TARGET, Const.INPUT_LEN)
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET, Const.INPUT_LEN)
# print("[INFO] Load existing vocab from %s!" % self.vocab_path)
word_list = []
@@ -96,47 +97,52 @@ class ExtCNNDMPipe(Pipe):
        vocabs = Vocabulary(max_size=self.vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
        vocabs.add_word_lst(word_list)
        vocabs.build_vocab()
-        db.set_vocab(vocabs, "vocab")
+        data_bundle.set_vocab(vocabs, "vocab")

-        if self.domain == True:
+        if self.domain is True:
            domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK)
-            domaindict.from_dataset(db.get_dataset("train"), field_name="publication")
-            db.set_vocab(domaindict, "domain")
-        return db
+            domaindict.from_dataset(data_bundle.get_dataset("train"), field_name="publication")
+            data_bundle.set_vocab(domaindict, "domain")
+        return data_bundle

    def process_from_file(self, paths=None):
        """
-        :param paths: dict or string
-        :return: DataBundle
-        """
-        db = DataBundle()
-        if isinstance(paths, dict):
-            for key, value in paths.items():
-                db.set_dataset(JsonLoader(fields={"text":None, "summary":None, "label":None, "publication":None})._load(value), key)
-        else:
-            db.set_dataset(JsonLoader(fields={"text":None, "summary":None, "label":None, "publication":None})._load(paths), 'test')
-        self.process(db)
+        :param paths: dict or string
+        :return: DataBundle
+        """
+        loader = ExtCNNDMLoader()
+        if self.vocab_path is None:
+            if paths is None:
+                paths = loader.download()
+            if not os.path.isdir(paths):
+                error_msg = 'vocab file is not defined!'
+                logger.error(error_msg)
+                raise RuntimeError(error_msg)
+            self.vocab_path = os.path.join(paths, 'vocab')
+        db = loader.load(paths=paths)
+        db = self.process(db)
+        for ds in db.datasets.values():
+            db.get_vocab("vocab").index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)
+        return db


def _lower_text(text_list):
    return [text.lower() for text in text_list]


def _split_list(text_list):
    return [text.split() for text in text_list]


def _convert_label(label, sent_len):
    np_label = np.zeros(sent_len, dtype=int)
    if label != []:
        np_label[np.array(label)] = 1
    return np_label.tolist()


def _pad_sent(text_wd, sent_max_len):
    pad_text_wd = []
    for sent_wd in text_wd:
@@ -148,6 +154,7 @@ def _pad_sent(text_wd, sent_max_len):
        pad_text_wd.append(sent_wd)
    return pad_text_wd


def _token_mask(text_wd, sent_max_len):
    token_mask_list = []
    for sent_wd in text_wd:
@@ -159,6 +166,7 @@ def _token_mask(text_wd, sent_max_len):
        token_mask_list.append(mask)
    return token_mask_list


def _pad_label(label, doc_max_timesteps):
    text_len = len(label)
    if text_len < doc_max_timesteps:
@@ -167,6 +175,7 @@ def _pad_label(label, doc_max_timesteps):
        pad_label = label[:doc_max_timesteps]
    return pad_label


def _pad_doc(text_wd, sent_max_len, doc_max_timesteps):
    text_len = len(text_wd)
    if text_len < doc_max_timesteps:
@@ -176,6 +185,7 @@ def _pad_doc(text_wd, sent_max_len, doc_max_timesteps):
        pad_text = text_wd[:doc_max_timesteps]
    return pad_text


def _sent_mask(text_wd, doc_max_timesteps):
    text_len = len(text_wd)
    if text_len < doc_max_timesteps:
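
To make the reorganized entry point concrete, here is a small usage sketch; it is not part of the diff and the path is illustrative. As the code above shows, when vocab_path is omitted, process_from_file expects paths to be a directory and resolves the vocab file inside it; otherwise it raises a RuntimeError.

# Sketch only; 'path/to/cnndm' is a hypothetical directory containing the jsonl files and a vocab file.
from fastNLP.io.pipe.summarization import ExtCNNDMPipe

pipe = ExtCNNDMPipe(vocab_size=100000, sent_max_len=100, doc_max_timesteps=50)
db = pipe.process_from_file('path/to/cnndm')
# db is a DataBundle holding the processed DataSets plus the "vocab" vocabulary
# (and a "domain" vocabulary when domain=True)
print(db)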


test/data_for_tests/io/cnndm/dev.label.jsonl  +4 -0 (file diff suppressed because it is too large)


test/data_for_tests/io/cnndm/test.label.jsonl  +4 -0 (file diff suppressed because it is too large)


test/data_for_tests/cnndm.jsonl → test/data_for_tests/io/cnndm/train.cnndm.jsonl


test/data_for_tests/cnndm.vocab → test/data_for_tests/io/cnndm/vocab


test/io/pipe/test_extcnndm.py → test/io/pipe/test_summary.py

@@ -1,59 +1,69 @@
-#!/usr/bin/python
-# -*- coding: utf-8 -*-
-
-# __author__="Danqing Wang"
-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-import unittest
-import os
-
-# import sys
-#
-# sys.path.append("../../../")
-
-from fastNLP.io import DataBundle
-from fastNLP.io.pipe.summarization import ExtCNNDMPipe
-
-
-class TestRunExtCNNDMPipe(unittest.TestCase):
-
-    def test_load(self):
-        data_set_dict = {
-            'CNNDM': {"train": 'test/data_for_tests/cnndm.jsonl'},
-        }
-        vocab_size = 100000
-        VOCAL_FILE = 'test/data_for_tests/cnndm.vocab'
-        sent_max_len = 100
-        doc_max_timesteps = 50
-        dbPipe = ExtCNNDMPipe(vocab_size=vocab_size,
-                              vocab_path=VOCAL_FILE,
-                              sent_max_len=sent_max_len,
-                              doc_max_timesteps=doc_max_timesteps)
-        dbPipe2 = ExtCNNDMPipe(vocab_size=vocab_size,
-                               vocab_path=VOCAL_FILE,
-                               sent_max_len=sent_max_len,
-                               doc_max_timesteps=doc_max_timesteps,
-                               domain=True)
-        for k, v in data_set_dict.items():
-            db = dbPipe.process_from_file(v)
-            db2 = dbPipe2.process_from_file(v)
-
-            # print(db2.get_dataset("train"))
-            self.assertTrue(isinstance(db, DataBundle))
-            self.assertTrue(isinstance(db2, DataBundle))
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+# __author__="Danqing Wang"
+
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import unittest
+import os
+
+from fastNLP.io import DataBundle
+from fastNLP.io.pipe.summarization import ExtCNNDMPipe
+
+
+class TestRunExtCNNDMPipe(unittest.TestCase):
+
+    def test_load(self):
+        data_dir = 'test/data_for_tests/io/cnndm'
+        vocab_size = 100000
+        VOCAL_FILE = 'test/data_for_tests/io/cnndm/vocab'
+        sent_max_len = 100
+        doc_max_timesteps = 50
+        dbPipe = ExtCNNDMPipe(vocab_size=vocab_size,
+                              vocab_path=VOCAL_FILE,
+                              sent_max_len=sent_max_len,
+                              doc_max_timesteps=doc_max_timesteps)
+        dbPipe2 = ExtCNNDMPipe(vocab_size=vocab_size,
+                               vocab_path=VOCAL_FILE,
+                               sent_max_len=sent_max_len,
+                               doc_max_timesteps=doc_max_timesteps,
+                               domain=True)
+        db = dbPipe.process_from_file(data_dir)
+        db2 = dbPipe2.process_from_file(data_dir)
+
+        self.assertTrue(isinstance(db, DataBundle))
+        self.assertTrue(isinstance(db2, DataBundle))
+
+        dbPipe3 = ExtCNNDMPipe(vocab_size=vocab_size,
+                               sent_max_len=sent_max_len,
+                               doc_max_timesteps=doc_max_timesteps,
+                               domain=True)
+        db3 = dbPipe3.process_from_file(data_dir)
+        self.assertTrue(isinstance(db3, DataBundle))
+
+        with self.assertRaises(RuntimeError):
+            dbPipe4 = ExtCNNDMPipe(vocab_size=vocab_size,
+                                   sent_max_len=sent_max_len,
+                                   doc_max_timesteps=doc_max_timesteps)
+            db4 = dbPipe4.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl'))
+
+        dbPipe5 = ExtCNNDMPipe(vocab_size=vocab_size,
+                               vocab_path=VOCAL_FILE,
+                               sent_max_len=sent_max_len,
+                               doc_max_timesteps=doc_max_timesteps,)
+        db5 = dbPipe5.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl'))
+        self.assertIsInstance(db5, DataBundle)

