
add doc for summarization pipe

Danqing Wang · 5 years ago · commit 8affc4a3ff · tags/v0.4.10
5 changed files with 48 additions and 200026 deletions:

  1. fastNLP/io/loader/json.py (+8, -9)
  2. fastNLP/io/pipe/summarization.py (+36, -15)
  3. reproduction/Summarization/README.md (+1, -1)
  4. test/io/pipe/cnndm.vocab (+0, -200000)
  5. test/io/pipe/test_extcnndm.py (+3, -1)

fastNLP/io/loader/json.py (+8, -9)

@@ -12,20 +12,19 @@ from ...core.instance import Instance


class JsonLoader(Loader):
    """
    Alias: :class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.loader.JsonLoader`

    Reads data in json format. The data must be stored one json object per line, each line being a json object with the attributes to read.

-    :param dict fields: the json attribute names to read, and the field_name each attribute is stored under in the DataSet.
-        Every `key` of ``fields`` must be an attribute name of the json object; the corresponding `value` is the `field_name`
-        used in the DataSet. A `value` of ``None`` keeps the json attribute name as the `field_name`.
-        ``fields`` itself may be ``None``, in which case all attributes of the json object are stored in the DataSet. Default: ``None``
-    :param bool dropna: whether to ignore invalid data; if ``True`` invalid lines are skipped, if ``False`` a ``ValueError``
-        is raised when invalid data is encountered. Default: ``False``
    """

    def __init__(self, fields=None, dropna=False):
+        """
+        :param dict fields: the json attribute names to read, and the field_name each attribute is stored under in the DataSet.
+            Every `key` of ``fields`` must be an attribute name of the json object; the corresponding `value` is the `field_name`
+            used in the DataSet. A `value` of ``None`` keeps the json attribute name as the `field_name`.
+            ``fields`` itself may be ``None``, in which case all attributes of the json object are stored in the DataSet. Default: ``None``
+        :param bool dropna: whether to ignore invalid data; if ``True`` invalid lines are skipped, if ``False`` a ``ValueError``
+            is raised when invalid data is encountered. Default: ``False``
+        """
        super(JsonLoader, self).__init__()
        self.dropna = dropna
        self.fields = None
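For readers unfamiliar with the loader, a minimal sketch of the `fields` mapping documented above; the jsonl path is hypothetical, and a value of ``None`` keeps the original json attribute name as the DataSet field name:

```python
from fastNLP.io import JsonLoader

# Hypothetical line-delimited json file where each line carries
# "text", "summary", "label" and "publication" attributes.
loader = JsonLoader(fields={"text": None, "summary": None,
                            "label": None, "publication": None},
                    dropna=False)
dataset = loader._load("cnndm_sample.jsonl")  # returns a DataSet, as used by ExtCNNDMPipe.process_from_file below
print(dataset)
```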


fastNLP/io/pipe/summarization.py (+36, -15)

@@ -20,7 +20,22 @@ TAG_UNK = "X"
class ExtCNNDMPipe(Pipe):
+    """
+    Preprocesses CNN/Daily Mail data for the extractive summarization task. After preprocessing, the data has the following structure:
+
+    .. csv-table::
+        :header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target"
+    """

    def __init__(self, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False):
+        """
+        :param vocab_size: int, vocabulary size
+        :param vocab_path: str, path to the external vocabulary file
+        :param sent_max_len: int, maximum sentence length; shorter sentences are padded, longer ones truncated
+        :param doc_max_timesteps: int, maximum number of sentences per document; shorter documents are padded, longer ones truncated
+        :param domain: bool, whether to build a domain vocabulary
+        """
        self.vocab_size = vocab_size
        self.vocab_path = vocab_path
        self.sent_max_len = sent_max_len
@@ -33,28 +48,34 @@ class ExtCNNDMPipe(Pipe):
        The incoming DataSet should have the following structure:

        .. csv-table::
-            :header: "text", "summary", "label", "domain"
-
-            "I got 'new' tires from them and... ", "The 'new' tires...", [0, 1], "cnndm"
-            "Don't waste your time. We had two...", "Time is precious", [1], "cnndm"
-            "...", "...", [], "cnndm"
+            :header: "text", "summary", "label", "publication"
+
+            ["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm"
+            ["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm"
+            ["..."], ["..."], [], "cnndm"

        :param data_bundle:
-        :return:
+        :return: the processed data contains
+
+        .. csv-table::
+            :header: "text_wd", "words", "seq_len", "target"
+
+            [["I","got",..."."],...,["..."]], [[54,89,...,5],...,[9,43,..,0]], [1,1,...,0], [0,1,...,0]
+            [["Don't","waste",...,"."],...,["..."]], [[5234,653,...,5],...,[87,234,..,0]], [1,1,...,0], [1,1,...,0]
+            [[""],...,[""]], [[],...,[]], [], []
        """
        db.apply(lambda x: _lower_text(x['text']), new_field_name='text')
        db.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
        db.apply(lambda x: _split_list(x['text']), new_field_name='text_wd')
-        db.apply(lambda x: _split_list(x['summary']), new_field_name='summary_wd')
-        db.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name="flatten_label")
+        db.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET)

-        db.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name="pad_text_wd")
-        db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask")
+        db.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT)
+        # db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask")

        # pad document
-        db.apply(lambda x: _pad_doc(x["pad_text_wd"], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT)
-        db.apply(lambda x: _sent_mask(x["pad_text_wd"], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN)
-        db.apply(lambda x: _pad_label(x["flatten_label"], self.doc_max_timesteps), new_field_name=Const.TARGET)
+        db.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT)
+        db.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN)
+        db.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET)

        db = _drop_empty_instance(db, "label")
@@ -87,15 +108,15 @@ class ExtCNNDMPipe(Pipe):
    def process_from_file(self, paths=None):
        """
-        :param paths:
+        :param paths: dict or string
        :return: DataBundle
        """
        db = DataBundle()
        if isinstance(paths, dict):
            for key, value in paths.items():
-                db.set_dataset(JsonLoader()._load(value), key)
+                db.set_dataset(JsonLoader(fields={"text":None, "summary":None, "label":None, "publication":None})._load(value), key)
        else:
-            db.set_dataset(JsonLoader()._load(paths), 'test')
+            db.set_dataset(JsonLoader(fields={"text":None, "summary":None, "label":None, "publication":None})._load(paths), 'test')
        self.process(db)
        for ds in db.datasets.values():
            db.get_vocab("vocab").index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)
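As a rough sketch of how the pipe documented above might be driven end to end; the paths and sizes below are illustrative placeholders (mirroring the fixtures in test/io/pipe/test_extcnndm.py) and are not part of this commit:

```python
from fastNLP.io.pipe.summarization import ExtCNNDMPipe

# Hypothetical vocabulary file and jsonl splits; each line of a split must carry
# the "text", "summary", "label" and "publication" attributes loaded above.
paths = {"train": "train.cnndm.jsonl", "dev": "val.cnndm.jsonl", "test": "test.cnndm.jsonl"}

pipe = ExtCNNDMPipe(vocab_size=100000,       # assumed size, should match the vocab file
                    vocab_path="cnndm.vocab",
                    sent_max_len=100,         # sentences padded/truncated to 100 tokens
                    doc_max_timesteps=50,     # documents padded/truncated to 50 sentences
                    domain=True)              # also build a domain vocabulary (see the __init__ doc above)
data_bundle = pipe.process_from_file(paths)

# Each dataset now carries the fields listed in the docstring:
# Const.INPUT ("words"), Const.INPUT_LEN ("seq_len") and Const.TARGET ("target").
print(data_bundle.get_dataset("train"))
```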


reproduction/Summarization/README.md (+1, -1)

@@ -18,7 +18,7 @@ Models implemented in FastNLP include:

The summarization task datasets provided here include:
-- CNN/DailyMail
+- CNN/DailyMail ([Get To The Point: Summarization with Pointer-Generator Networks](http://arxiv.org/abs/1704.04368))
- Newsroom
- The New York Times Annotated Corpus
- NYT


test/io/pipe/cnndm.vocab (+0, -200000)
File diff suppressed because it is too large.


test/io/pipe/test_extcnndm.py (+3, -1)

@@ -44,11 +44,13 @@ class TestRunExtCNNDMPipe(unittest.TestCase):
                              vocab_path=VOCAL_FILE,
                              sent_max_len=sent_max_len,
                              doc_max_timesteps=doc_max_timesteps,
-                              domain=True)
+                              domain=True)
        for k, v in data_set_dict.items():
            db = dbPipe.process_from_file(v)
            db2 = dbPipe2.process_from_file(v)
+            # print(db2.get_dataset("train"))
            self.assertTrue(isinstance(db, DataBundle))
            self.assertTrue(isinstance(db2, DataBundle))
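A small hedged sketch of inspecting the processed fields that the test above asserts on; the field names follow the Const constants used by the pipe, and the vocabulary path, data path, and sizes are again only placeholders:

```python
from fastNLP.core.const import Const
from fastNLP.io.pipe.summarization import ExtCNNDMPipe

# Placeholder vocabulary and data paths; a single path is loaded into the 'test' split.
pipe = ExtCNNDMPipe(vocab_size=100000, vocab_path="cnndm.vocab",
                    sent_max_len=100, doc_max_timesteps=50)
db = pipe.process_from_file("test.cnndm.jsonl")

instance = db.get_dataset("test")[0]
print(instance[Const.INPUT])      # padded word ids: doc_max_timesteps rows of sent_max_len tokens
print(instance[Const.INPUT_LEN])  # per-sentence mask produced by _sent_mask
print(instance[Const.TARGET])     # padded 0/1 extraction label for each sentence
```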

