Commit a4babc04e2 — Yige Xu, 5 years ago (tags/v0.4.10)

1. add summary loader; 2. reorganize code in ExtCNNDMPipe; 3. reorganize test data and code for ExtCNNDMPipe;

7 changed files with 188 additions and 97 deletions:

1. fastNLP/io/loader/summarization.py (+63, -0)
2. fastNLP/io/pipe/summarization.py (+48, -38)
3. test/data_for_tests/io/cnndm/dev.label.jsonl (+4, -0)
4. test/data_for_tests/io/cnndm/test.label.jsonl (+4, -0)
5. test/data_for_tests/io/cnndm/train.cnndm.jsonl (+0, -0)
6. test/data_for_tests/io/cnndm/vocab (+0, -0)
7. test/io/pipe/test_summary.py (+69, -59)

fastNLP/io/loader/summarization.py (+63, -0)

@@ -0,0 +1,63 @@
"""undocumented"""

__all__ = [
"ExtCNNDMLoader"
]

import os
from typing import Union, Dict

from ..data_bundle import DataBundle
from ..utils import check_loader_paths
from .json import JsonLoader


class ExtCNNDMLoader(JsonLoader):
"""
读取之后的DataSet中的field情况为

.. csv-table::
:header: "text", "summary", "label", "publication"

["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm"
["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm"
["..."], ["..."], [], "cnndm"

"""

def __init__(self, fields=None):
fields = fields or {"text": None, "summary": None, "label": None, "publication": None}
super(ExtCNNDMLoader, self).__init__(fields=fields)

def load(self, paths: Union[str, Dict[str, str]] = None):
"""
从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。

读取的field根据ExtCNNDMLoader初始化时传入的headers决定。

:param str paths: 传入一个目录, 将在该目录下寻找train.label.jsonl, dev.label.jsonl
test.label.jsonl三个文件(该目录还应该需要有一个名字为vocab的文件,在 :class:`~fastNLP.io.ExtCNNDMPipe`
当中需要用到)。

:return: 返回 :class:`~fastNLP.io.DataBundle`
"""
if paths is None:
paths = self.download()
paths = check_loader_paths(paths)
if ('train' in paths) and ('test' not in paths):
paths['test'] = paths['train']
paths.pop('train')

datasets = {name: self._load(path) for name, path in paths.items()}
data_bundle = DataBundle(datasets=datasets)
return data_bundle

def download(self):
"""
如果你使用了这个数据,请引用

https://arxiv.org/pdf/1506.03340.pdf
:return:
"""
output_dir = self._get_dataset_path('ext-cnndm')
return output_dir
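
A minimal usage sketch of the new loader (hedged: the directory path is illustrative and assumes the layout described in the docstring above):

from fastNLP.io.loader.summarization import ExtCNNDMLoader

# Illustrative directory containing train.label.jsonl, dev.label.jsonl and
# test.label.jsonl (plus a `vocab` file used later by ExtCNNDMPipe).
data_bundle = ExtCNNDMLoader().load('path/to/cnndm')
# Each split is a DataSet with `text`, `summary`, `label` and `publication` fields.
print(data_bundle)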

fastNLP/io/pipe/summarization.py (+48, -38)

@@ -1,15 +1,14 @@
 """undocumented"""
+import os
 
 import numpy as np
 
 from .pipe import Pipe
-from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
-from ..loader.json import JsonLoader
+from .utils import _drop_empty_instance
+from ..loader.summarization import ExtCNNDMLoader
 from ..data_bundle import DataBundle
-from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
 from ...core.const import Const
-from ...core.dataset import DataSet
-from ...core.instance import Instance
 from ...core.vocabulary import Vocabulary
+from ...core._logger import logger
 
 WORD_PAD = "[PAD]"
@@ -18,7 +17,6 @@ DOMAIN_UNK = "X"
 TAG_UNK = "X"
 
-
 class ExtCNNDMPipe(Pipe):
     """
     Preprocess CNN/Daily Mail data for the extractive summarization task. After preprocessing, the data has the following structure:
@@ -27,13 +25,13 @@ class ExtCNNDMPipe(Pipe):
         :header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target"
     """
 
-    def __init__(self, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False):
+    def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False):
         """
         :param vocab_size: int, vocabulary size
-        :param vocab_path: str, path to an external vocabulary file
         :param sent_max_len: int, maximum sentence length; shorter sentences are padded, longer ones are truncated
         :param doc_max_timesteps: int, maximum number of sentences per document; shorter documents are padded, longer ones are truncated
+        :param vocab_path: str, path to an external vocabulary file
         :param domain: bool, whether to build a domain vocabulary
         """
         self.vocab_size = vocab_size
@@ -42,8 +40,7 @@ class ExtCNNDMPipe(Pipe):
         self.doc_max_timesteps = doc_max_timesteps
         self.domain = domain
 
-
-    def process(self, db: DataBundle):
+    def process(self, data_bundle: DataBundle):
         """
         The DataSet passed in should have the following structure:
@@ -64,24 +61,28 @@ class ExtCNNDMPipe(Pipe):
             [[""],...,[""]], [[],...,[]], [], []
         """
 
-        db.apply(lambda x: _lower_text(x['text']), new_field_name='text')
-        db.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
-        db.apply(lambda x: _split_list(x['text']), new_field_name='text_wd')
-        db.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET)
+        if self.vocab_path is None:
+            error_msg = 'vocab file is not defined!'
+            logger.error(error_msg)
+            raise RuntimeError(error_msg)
+        data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text')
+        data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
+        data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd')
+        data_bundle.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET)
 
-        db.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT)
+        data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT)
         # db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask")
 
         # pad document
-        db.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT)
-        db.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN)
-        db.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET)
+        data_bundle.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT)
+        data_bundle.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN)
+        data_bundle.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET)
 
-        db = _drop_empty_instance(db, "label")
+        data_bundle = _drop_empty_instance(data_bundle, "label")
 
         # set input and target
-        db.set_input(Const.INPUT, Const.INPUT_LEN)
-        db.set_target(Const.TARGET, Const.INPUT_LEN)
+        data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
+        data_bundle.set_target(Const.TARGET, Const.INPUT_LEN)
 
         # print("[INFO] Load existing vocab from %s!" % self.vocab_path)
         word_list = []
@@ -96,47 +97,52 @@ class ExtCNNDMPipe(Pipe):
         vocabs = Vocabulary(max_size=self.vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
         vocabs.add_word_lst(word_list)
         vocabs.build_vocab()
-        db.set_vocab(vocabs, "vocab")
+        data_bundle.set_vocab(vocabs, "vocab")
 
-        if self.domain == True:
+        if self.domain is True:
             domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK)
-            domaindict.from_dataset(db.get_dataset("train"), field_name="publication")
-            db.set_vocab(domaindict, "domain")
-        return db
+            domaindict.from_dataset(data_bundle.get_dataset("train"), field_name="publication")
+            data_bundle.set_vocab(domaindict, "domain")
+        return data_bundle
 
     def process_from_file(self, paths=None):
         """
-        :param paths: dict or string
-        :return: DataBundle
-        """
-        db = DataBundle()
-        if isinstance(paths, dict):
-            for key, value in paths.items():
-                db.set_dataset(JsonLoader(fields={"text":None, "summary":None, "label":None, "publication":None})._load(value), key)
-        else:
-            db.set_dataset(JsonLoader(fields={"text":None, "summary":None, "label":None, "publication":None})._load(paths), 'test')
-        self.process(db)
+        :param paths: dict or string
+        :return: DataBundle
+        """
+        loader = ExtCNNDMLoader()
+        if self.vocab_path is None:
+            if paths is None:
+                paths = loader.download()
+            if not os.path.isdir(paths):
+                error_msg = 'vocab file is not defined!'
+                logger.error(error_msg)
+                raise RuntimeError(error_msg)
+            self.vocab_path = os.path.join(paths, 'vocab')
+        db = loader.load(paths=paths)
+        db = self.process(db)
         for ds in db.datasets.values():
             db.get_vocab("vocab").index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)
         return db
 
 
 def _lower_text(text_list):
     return [text.lower() for text in text_list]
 
 
 def _split_list(text_list):
     return [text.split() for text in text_list]
 
 
 def _convert_label(label, sent_len):
     np_label = np.zeros(sent_len, dtype=int)
     if label != []:
         np_label[np.array(label)] = 1
     return np_label.tolist()
 
 
 def _pad_sent(text_wd, sent_max_len):
     pad_text_wd = []
     for sent_wd in text_wd:
@@ -148,6 +154,7 @@ def _pad_sent(text_wd, sent_max_len):
         pad_text_wd.append(sent_wd)
     return pad_text_wd
 
+
 def _token_mask(text_wd, sent_max_len):
     token_mask_list = []
     for sent_wd in text_wd:
@@ -159,6 +166,7 @@ def _token_mask(text_wd, sent_max_len):
         token_mask_list.append(mask)
     return token_mask_list
 
+
 def _pad_label(label, doc_max_timesteps):
     text_len = len(label)
     if text_len < doc_max_timesteps:
@@ -167,6 +175,7 @@ def _pad_label(label, doc_max_timesteps):
         pad_label = label[:doc_max_timesteps]
     return pad_label
 
+
 def _pad_doc(text_wd, sent_max_len, doc_max_timesteps):
     text_len = len(text_wd)
     if text_len < doc_max_timesteps:
@@ -176,6 +185,7 @@ def _pad_doc(text_wd, sent_max_len, doc_max_timesteps):
         pad_text = text_wd[:doc_max_timesteps]
     return pad_text
 
+
 def _sent_mask(text_wd, doc_max_timesteps):
     text_len = len(text_wd)
     if text_len < doc_max_timesteps:
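
With this reorganization, process_from_file can derive the vocab file from the data directory when vocab_path is omitted. A minimal sketch (hedged: the directory path is illustrative and assumes the loader's expected layout):

from fastNLP.io.pipe.summarization import ExtCNNDMPipe

# With vocab_path omitted, `paths` must be a directory containing a file named
# `vocab`; otherwise process_from_file raises RuntimeError.
pipe = ExtCNNDMPipe(vocab_size=100000, sent_max_len=100, doc_max_timesteps=50)
db = pipe.process_from_file('path/to/cnndm')  # *.label.jsonl files + vocab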


test/data_for_tests/io/cnndm/dev.label.jsonl (+4, -0)
File diff suppressed because it is too large


test/data_for_tests/io/cnndm/test.label.jsonl (+4, -0)
File diff suppressed because it is too large


test/data_for_tests/cnndm.jsonl → test/data_for_tests/io/cnndm/train.cnndm.jsonl (renamed)


test/data_for_tests/cnndm.vocab → test/data_for_tests/io/cnndm/vocab (renamed)


test/io/pipe/test_extcnndm.py → test/io/pipe/test_summary.py (renamed, +69, -59)

@@ -1,59 +1,69 @@
 #!/usr/bin/python
 # -*- coding: utf-8 -*-
+
 # __author__="Danqing Wang"
+
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 # http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import unittest
 import os
 
-# import sys
-#
-# sys.path.append("../../../")
-
 from fastNLP.io import DataBundle
 from fastNLP.io.pipe.summarization import ExtCNNDMPipe
 
 
 class TestRunExtCNNDMPipe(unittest.TestCase):
 
     def test_load(self):
-        data_set_dict = {
-            'CNNDM': {"train": 'test/data_for_tests/cnndm.jsonl'},
-        }
+        data_dir = 'test/data_for_tests/io/cnndm'
         vocab_size = 100000
-        VOCAL_FILE = 'test/data_for_tests/cnndm.vocab'
+        VOCAL_FILE = 'test/data_for_tests/io/cnndm/vocab'
         sent_max_len = 100
         doc_max_timesteps = 50
         dbPipe = ExtCNNDMPipe(vocab_size=vocab_size,
                               vocab_path=VOCAL_FILE,
                               sent_max_len=sent_max_len,
                               doc_max_timesteps=doc_max_timesteps)
         dbPipe2 = ExtCNNDMPipe(vocab_size=vocab_size,
                                vocab_path=VOCAL_FILE,
                                sent_max_len=sent_max_len,
                                doc_max_timesteps=doc_max_timesteps,
                                domain=True)
-        for k, v in data_set_dict.items():
-            db = dbPipe.process_from_file(v)
-            db2 = dbPipe2.process_from_file(v)
-
-            # print(db2.get_dataset("train"))
-            self.assertTrue(isinstance(db, DataBundle))
-            self.assertTrue(isinstance(db2, DataBundle))
+        db = dbPipe.process_from_file(data_dir)
+        db2 = dbPipe2.process_from_file(data_dir)
+
+        self.assertTrue(isinstance(db, DataBundle))
+        self.assertTrue(isinstance(db2, DataBundle))
+
+        dbPipe3 = ExtCNNDMPipe(vocab_size=vocab_size,
+                               sent_max_len=sent_max_len,
+                               doc_max_timesteps=doc_max_timesteps,
+                               domain=True)
+        db3 = dbPipe3.process_from_file(data_dir)
+        self.assertTrue(isinstance(db3, DataBundle))
+
+        with self.assertRaises(RuntimeError):
+            dbPipe4 = ExtCNNDMPipe(vocab_size=vocab_size,
+                                   sent_max_len=sent_max_len,
+                                   doc_max_timesteps=doc_max_timesteps)
+            db4 = dbPipe4.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl'))
+
+        dbPipe5 = ExtCNNDMPipe(vocab_size=vocab_size,
+                               vocab_path=VOCAL_FILE,
+                               sent_max_len=sent_max_len,
+                               doc_max_timesteps=doc_max_timesteps,)
+        db5 = dbPipe5.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl'))
+        self.assertIsInstance(db5, DataBundle)
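
As a hedged follow-up to the domain=True cases above (the paths mirror the test data; the vocabulary names follow the pipe diff), the processed bundle carries both a word vocabulary and a domain vocabulary:

from fastNLP.io.pipe.summarization import ExtCNNDMPipe

# Sketch: inspect the two vocabularies attached by the pipe.
db2 = ExtCNNDMPipe(vocab_size=100000,
                   vocab_path='test/data_for_tests/io/cnndm/vocab',
                   sent_max_len=100, doc_max_timesteps=50,
                   domain=True).process_from_file('test/data_for_tests/io/cnndm')
print(db2.get_vocab('vocab'))   # word vocabulary (padding WORD_PAD = "[PAD]")
print(db2.get_vocab('domain'))  # publication vocabulary (unknown DOMAIN_UNK = "X")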

