|
|
@@ -1,15 +1,14 @@ |
|
|
|
"""undocumented"""
|
|
|
|
import os
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
from .pipe import Pipe
|
|
|
|
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
|
|
|
|
from ..loader.json import JsonLoader
|
|
|
|
from .utils import _drop_empty_instance
|
|
|
|
from ..loader.summarization import ExtCNNDMLoader
|
|
|
|
from ..data_bundle import DataBundle
|
|
|
|
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
|
|
|
|
from ...core.const import Const
|
|
|
|
from ...core.dataset import DataSet
|
|
|
|
from ...core.instance import Instance
|
|
|
|
from ...core.vocabulary import Vocabulary
|
|
|
|
from ...core._logger import logger
|
|
|
|
|
|
|
|
|
|
|
|
WORD_PAD = "[PAD]"
|
|
|
@@ -18,7 +17,6 @@ DOMAIN_UNK = "X" |
|
|
|
TAG_UNK = "X"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExtCNNDMPipe(Pipe):
|
|
|
|
"""
|
|
|
|
对CNN/Daily Mail数据进行适用于extractive summarization task的预处理,预处理之后的数据,具备以下结构:
|
|
|
@@ -27,13 +25,13 @@ class ExtCNNDMPipe(Pipe): |
|
|
|
:header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target"
|
|
|
|
|
|
|
|
"""
|
|
|
|
def __init__(self, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False):
|
|
|
|
def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False):
|
|
|
|
"""
|
|
|
|
|
|
|
|
:param vocab_size: int, 词表大小
|
|
|
|
:param vocab_path: str, 外部词表路径
|
|
|
|
:param sent_max_len: int, 句子最大长度,不足的句子将padding,超出的将截断
|
|
|
|
:param doc_max_timesteps: int, 文章最多句子个数,不足的将padding,超出的将截断
|
|
|
|
:param vocab_path: str, 外部词表路径
|
|
|
|
:param domain: bool, 是否需要建立domain词表
|
|
|
|
"""
|
|
|
|
self.vocab_size = vocab_size
|
|
|
@@ -42,8 +40,7 @@ class ExtCNNDMPipe(Pipe): |
|
|
|
self.doc_max_timesteps = doc_max_timesteps
|
|
|
|
self.domain = domain
|
|
|
|
|
|
|
|
|
|
|
|
def process(self, db: DataBundle):
|
|
|
|
def process(self, data_bundle: DataBundle):
|
|
|
|
"""
|
|
|
|
传入的DataSet应该具备如下的结构
|
|
|
|
|
|
|
@@ -64,24 +61,28 @@ class ExtCNNDMPipe(Pipe): |
|
|
|
[[""],...,[""]], [[],...,[]], [], []
|
|
|
|
"""
|
|
|
|
|
|
|
|
db.apply(lambda x: _lower_text(x['text']), new_field_name='text')
|
|
|
|
db.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
|
|
|
|
db.apply(lambda x: _split_list(x['text']), new_field_name='text_wd')
|
|
|
|
db.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET)
|
|
|
|
if self.vocab_path is None:
|
|
|
|
error_msg = 'vocab file is not defined!'
|
|
|
|
logger.error(error_msg)
|
|
|
|
raise RuntimeError(error_msg)
|
|
|
|
data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text')
|
|
|
|
data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
|
|
|
|
data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd')
|
|
|
|
data_bundle.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET)
|
|
|
|
|
|
|
|
db.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT)
|
|
|
|
data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT)
|
|
|
|
# db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask")
|
|
|
|
|
|
|
|
# pad document
|
|
|
|
db.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT)
|
|
|
|
db.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN)
|
|
|
|
db.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET)
|
|
|
|
data_bundle.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT)
|
|
|
|
data_bundle.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN)
|
|
|
|
data_bundle.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET)
|
|
|
|
|
|
|
|
db = _drop_empty_instance(db, "label")
|
|
|
|
data_bundle = _drop_empty_instance(data_bundle, "label")
|
|
|
|
|
|
|
|
# set input and target
|
|
|
|
db.set_input(Const.INPUT, Const.INPUT_LEN)
|
|
|
|
db.set_target(Const.TARGET, Const.INPUT_LEN)
|
|
|
|
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
|
|
|
|
data_bundle.set_target(Const.TARGET, Const.INPUT_LEN)
|
|
|
|
|
|
|
|
# print("[INFO] Load existing vocab from %s!" % self.vocab_path)
|
|
|
|
word_list = []
|
|
|
@@ -96,47 +97,52 @@ class ExtCNNDMPipe(Pipe): |
|
|
|
vocabs = Vocabulary(max_size=self.vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
|
|
|
|
vocabs.add_word_lst(word_list)
|
|
|
|
vocabs.build_vocab()
|
|
|
|
db.set_vocab(vocabs, "vocab")
|
|
|
|
data_bundle.set_vocab(vocabs, "vocab")
|
|
|
|
|
|
|
|
if self.domain == True:
|
|
|
|
if self.domain is True:
|
|
|
|
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK)
|
|
|
|
domaindict.from_dataset(db.get_dataset("train"), field_name="publication")
|
|
|
|
db.set_vocab(domaindict, "domain")
|
|
|
|
|
|
|
|
return db
|
|
|
|
domaindict.from_dataset(data_bundle.get_dataset("train"), field_name="publication")
|
|
|
|
data_bundle.set_vocab(domaindict, "domain")
|
|
|
|
|
|
|
|
return data_bundle
|
|
|
|
|
|
|
|
def process_from_file(self, paths=None):
|
|
|
|
"""
|
|
|
|
:param paths: dict or string
|
|
|
|
:return: DataBundle
|
|
|
|
"""
|
|
|
|
db = DataBundle()
|
|
|
|
if isinstance(paths, dict):
|
|
|
|
for key, value in paths.items():
|
|
|
|
db.set_dataset(JsonLoader(fields={"text":None, "summary":None, "label":None, "publication":None})._load(value), key)
|
|
|
|
else:
|
|
|
|
db.set_dataset(JsonLoader(fields={"text":None, "summary":None, "label":None, "publication":None})._load(paths), 'test')
|
|
|
|
self.process(db)
|
|
|
|
:param paths: dict or string
|
|
|
|
:return: DataBundle
|
|
|
|
"""
|
|
|
|
loader = ExtCNNDMLoader()
|
|
|
|
if self.vocab_path is None:
|
|
|
|
if paths is None:
|
|
|
|
paths = loader.download()
|
|
|
|
if not os.path.isdir(paths):
|
|
|
|
error_msg = 'vocab file is not defined!'
|
|
|
|
logger.error(error_msg)
|
|
|
|
raise RuntimeError(error_msg)
|
|
|
|
self.vocab_path = os.path.join(paths, 'vocab')
|
|
|
|
db = loader.load(paths=paths)
|
|
|
|
db = self.process(db)
|
|
|
|
for ds in db.datasets.values():
|
|
|
|
db.get_vocab("vocab").index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)
|
|
|
|
|
|
|
|
return db
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _lower_text(text_list):
|
|
|
|
return [text.lower() for text in text_list]
|
|
|
|
|
|
|
|
|
|
|
|
def _split_list(text_list):
|
|
|
|
return [text.split() for text in text_list]
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_label(label, sent_len):
|
|
|
|
np_label = np.zeros(sent_len, dtype=int)
|
|
|
|
if label != []:
|
|
|
|
np_label[np.array(label)] = 1
|
|
|
|
return np_label.tolist()
|
|
|
|
|
|
|
|
|
|
|
|
def _pad_sent(text_wd, sent_max_len):
|
|
|
|
pad_text_wd = []
|
|
|
|
for sent_wd in text_wd:
|
|
|
@@ -148,6 +154,7 @@ def _pad_sent(text_wd, sent_max_len): |
|
|
|
pad_text_wd.append(sent_wd)
|
|
|
|
return pad_text_wd
|
|
|
|
|
|
|
|
|
|
|
|
def _token_mask(text_wd, sent_max_len):
|
|
|
|
token_mask_list = []
|
|
|
|
for sent_wd in text_wd:
|
|
|
@@ -159,6 +166,7 @@ def _token_mask(text_wd, sent_max_len): |
|
|
|
token_mask_list.append(mask)
|
|
|
|
return token_mask_list
|
|
|
|
|
|
|
|
|
|
|
|
def _pad_label(label, doc_max_timesteps):
|
|
|
|
text_len = len(label)
|
|
|
|
if text_len < doc_max_timesteps:
|
|
|
@@ -167,6 +175,7 @@ def _pad_label(label, doc_max_timesteps): |
|
|
|
pad_label = label[:doc_max_timesteps]
|
|
|
|
return pad_label
|
|
|
|
|
|
|
|
|
|
|
|
def _pad_doc(text_wd, sent_max_len, doc_max_timesteps):
|
|
|
|
text_len = len(text_wd)
|
|
|
|
if text_len < doc_max_timesteps:
|
|
|
@@ -176,6 +185,7 @@ def _pad_doc(text_wd, sent_max_len, doc_max_timesteps): |
|
|
|
pad_text = text_wd[:doc_max_timesteps]
|
|
|
|
return pad_text
|
|
|
|
|
|
|
|
|
|
|
|
def _sent_mask(text_wd, doc_max_timesteps):
|
|
|
|
text_len = len(text_wd)
|
|
|
|
if text_len < doc_max_timesteps:
|
|
|
|