@@ -20,7 +20,22 @@ TAG_UNK = "X"
 
 
 class ExtCNNDMPipe(Pipe):
+    """
+    Preprocess the CNN/Daily Mail data for the extractive summarization task; the preprocessed data has the following structure:
+
+    .. csv-table::
+       :header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target"
+
+    """
     def __init__(self, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False):
+        """
+        :param vocab_size: int, vocabulary size
+        :param vocab_path: str, path to the external vocabulary file
+        :param sent_max_len: int, maximum sentence length; shorter sentences are padded, longer ones are truncated
+        :param doc_max_timesteps: int, maximum number of sentences per document; shorter documents are padded, longer ones are truncated
+        :param domain: bool, whether a domain vocabulary needs to be built
+        """
         self.vocab_size = vocab_size
         self.vocab_path = vocab_path
         self.sent_max_len = sent_max_len
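
For orientation, here is a minimal construction sketch matching the parameters documented in the new docstring above. The import path and all concrete values (vocabulary size, vocabulary file, length limits) are illustrative assumptions, not taken from this patch.

    from fastNLP.io.pipe.summarization import ExtCNNDMPipe  # import path assumed

    # Illustrative values only; "vocab.txt" is a placeholder path.
    pipe = ExtCNNDMPipe(
        vocab_size=100000,       # cap on the vocabulary size
        vocab_path="vocab.txt",  # external vocabulary file
        sent_max_len=100,        # sentences padded/truncated to 100 tokens
        doc_max_timesteps=50,    # documents padded/truncated to 50 sentences
        domain=False,            # no separate domain vocabulary
    )
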
@@ -33,28 +48,34 @@ class ExtCNNDMPipe(Pipe):
         The input DataSet should have the following structure:
 
         .. csv-table::
-           :header: "text", "summary", "label", "domain"
+           :header: "text", "summary", "label", "publication"
 
-           "I got 'new' tires from them and... ", "The 'new' tires...", [0, 1], "cnndm"
-           "Don't waste your time. We had two...", "Time is precious", [1], "cnndm"
-           "...", "...", [], "cnndm"
+           ["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm"
+           ["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm"
+           ["..."], ["..."], [], "cnndm"
 
         :param data_bundle:
-        :return:
+        :return: the processed data contains
+
+        .. csv-table::
+           :header: "text_wd", "words", "seq_len", "target"
+
+           [["I","got",..."."],...,["..."]], [[54,89,...,5],...,[9,43,..,0]], [1,1,...,0], [0,1,...,0]
+           [["Don't","waste",...,"."],...,["..."]], [[5234,653,...,5],...,[87,234,..,0]], [1,1,...,0], [1,1,...,0]
+           [[""],...,[""]], [[],...,[]], [], []
         """
         db.apply(lambda x: _lower_text(x['text']), new_field_name='text')
         db.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
         db.apply(lambda x: _split_list(x['text']), new_field_name='text_wd')
-        db.apply(lambda x: _split_list(x['summary']), new_field_name='summary_wd')
-        db.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name="flatten_label")
+        db.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET)
 
-        db.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name="pad_text_wd")
-        db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask")
+        db.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT)
+        # db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask")
 
         # pad document
-        db.apply(lambda x: _pad_doc(x["pad_text_wd"], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT)
-        db.apply(lambda x: _sent_mask(x["pad_text_wd"], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN)
-        db.apply(lambda x: _pad_label(x["flatten_label"], self.doc_max_timesteps), new_field_name=Const.TARGET)
+        db.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT)
+        db.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN)
+        db.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET)
 
         db = _drop_empty_instance(db, "label")
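
To make the new field flow concrete: _convert_label is called with the raw label list and the number of sentences, and the resulting Const.TARGET examples look like [0, 1, ..., 0], so the assumed semantics are "indices of extract-worthy sentences become a per-sentence 0/1 vector, later padded to doc_max_timesteps". The sketch below only illustrates that assumption; it is not the patch's actual helper code.

    def convert_label(label, doc_len):
        # Assumed behaviour of _convert_label: turn extractive sentence
        # indices, e.g. [0, 2], into a 0/1 vector of length doc_len.
        flatten = [0] * doc_len
        for idx in label:
            if 0 <= idx < doc_len:
                flatten[idx] = 1
        return flatten

    def pad_label(label_vec, doc_max_timesteps):
        # Assumed behaviour of _pad_label: pad or truncate the 0/1 vector
        # to exactly doc_max_timesteps entries.
        truncated = label_vec[:doc_max_timesteps]
        return truncated + [0] * (doc_max_timesteps - len(truncated))

    # A 5-sentence article whose extract-worthy sentences are 0 and 2,
    # padded to a document length of 8 sentences.
    assert pad_label(convert_label([0, 2], 5), 8) == [1, 0, 1, 0, 0, 0, 0, 0]
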
@@ -87,15 +108,15 @@ class ExtCNNDMPipe(Pipe): |
|
|
|
|
|
|
|
|
def process_from_file(self, paths=None):
|
|
|
def process_from_file(self, paths=None):
|
|
|
"""
|
|
|
"""
|
|
|
:param paths:
|
|
|
|
|
|
|
|
|
:param paths: dict or string
|
|
|
:return: DataBundle
|
|
|
:return: DataBundle
|
|
|
"""
|
|
|
"""
|
|
|
db = DataBundle()
|
|
|
db = DataBundle()
|
|
|
if isinstance(paths, dict):
|
|
|
if isinstance(paths, dict):
|
|
|
for key, value in paths.items():
|
|
|
for key, value in paths.items():
|
|
|
db.set_dataset(JsonLoader()._load(value), key)
|
|
|
|
|
|
|
|
|
db.set_dataset(JsonLoader(fields={"text":None, "summary":None, "label":None, "publication":None})._load(value), key)
|
|
|
else:
|
|
|
else:
|
|
|
db.set_dataset(JsonLoader()._load(paths), 'test')
|
|
|
|
|
|
|
|
|
db.set_dataset(JsonLoader(fields={"text":None, "summary":None, "label":None, "publication":None})._load(paths), 'test')
|
|
|
self.process(db)
|
|
|
self.process(db)
|
|
|
for ds in db.datasets.values():
|
|
|
for ds in db.datasets.values():
|
|
|
db.get_vocab("vocab").index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)
|
|
|
db.get_vocab("vocab").index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)
|
|
|
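
End to end, a usage sketch under the same assumptions as the earlier example: the pipe object constructed above, jsonl files whose records carry "text", "summary", "label" and "publication" fields, and placeholder file names that are not part of this patch.

    paths = {
        "train": "train.label.jsonl",
        "dev": "val.label.jsonl",
        "test": "test.label.jsonl",
    }
    data_bundle = pipe.process_from_file(paths)

    # After processing, each DataSet carries Const.INPUT (padded word ids),
    # Const.INPUT_LEN (sentence mask) and Const.TARGET (padded 0/1 labels),
    # as set up in the process() hunk above.
    print(data_bundle.get_dataset("train"))
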