|
|
@@ -22,21 +22,56 @@ class CoreferencePipe(Pipe): |
|
|
|
self.config = config |
|
|
|
|
|
|
|
def process(self, data_bundle: DataBundle): |
|
|
|
""" |
|
|
|
对load进来的数据进一步处理 |
|
|
|
原始数据包含:raw_key,raw_speakers,raw_words,raw_clusters |
|
|
|
.. csv-table:: |
|
|
|
:header: "raw_key", "raw_speakers","raw_words","raw_clusters" |







"bc/cctv/00/cctv_0000_0", "[["Speaker#1", "Speaker#1"],[]]","[["I","am"],[]]","[[[2,3],[6,7]],[[10,12],[20,22]]]" |



"bc/cctv/00/cctv_0000_1", "[["Speaker#1", "Speaker#1"],[]]","[["He","is"],[]]","[[[2,3],[6,7]],[[10,12],[20,22]]]" |
|
|
|
"[...]", "[...]","[...]","[...]" |
|
|
|
|
|
|
|
处理完成后数据包含文章类别、speaker信息、句子信息、句子对应的index、char、句子长度、target: |
|
|
|
.. csv-table:: |
|
|
|
:header: "words1", "words2","words3","words4","chars","seq_len","target" |
|
|
|
|
|
|
|
"bc", "[[0,0],[1,1]]","[["I","am"],[]]",[[1,2],[]],[[[1],[2,3]],[]],[2,3],"[[[2,3],[6,7]],[[10,12],[20,22]]]" |
|
|
|
"[...]", "[...]","[...]","[...]","[...]","[...]","[...]" |
|
|
|
|
|
|
|
|
|
|
|
:param data_bundle: |
|
|
|
:return: |
|
|
|
""" |
|
|
|
genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])} |
|
|
|
vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name=Const.INPUTS(2)) |
|
|
|
vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name="raw_words") |
|
|
|
vocab.build_vocab() |
|
|
|
word2id = vocab.word2idx |
|
|
|
data_bundle.vocabs = {"vocab":vocab} |
|
|
|
char_dict = get_char_dict(self.config.char_path) |
|
|
|
data_bundle.set_vocab(vocab,"vocab") |
|
|
|
if self.config.char_path: |
|
|
|
char_dict = get_char_dict(self.config.char_path) |
|
|
|
else: |
|
|
|
char_set = set() |
|
|
|
for i,w in enumerate(word2id): |
|
|
|
if i < 2: |
|
|
|
continue |
|
|
|
for c in w: |
|
|
|
char_set.add(c) |
|
|
|
|
|
|
|
char_dict = collections.defaultdict(int) |
|
|
|
char_dict.update({c: i for i, c in enumerate(char_set)}) |
|
|
|
|
|
|
|
for name, ds in data_bundle.datasets.items(): |
|
|
|
# genre |
|
|
|
ds.apply(lambda x: genres[x[Const.INPUTS(0)][:2]], new_field_name=Const.INPUTS(0)) |
|
|
|
ds.apply(lambda x: genres[x["raw_key"][:2]], new_field_name=Const.INPUTS(0)) |
|
|
|
|
|
|
|
# speaker_ids_np |
|
|
|
ds.apply(lambda x: speaker2numpy(x[Const.INPUTS(1)], self.config.max_sentences, is_train=name == 'train'), |
|
|
|
ds.apply(lambda x: speaker2numpy(x["raw_speakers"], self.config.max_sentences, is_train=name == 'train'), |
|
|
|
new_field_name=Const.INPUTS(1)) |
|
|
|
|
|
|
|
# sentences |
|
|
|
ds.rename_field("raw_words",Const.INPUTS(2)) |
|
|
|
|
|
|
|
# doc_np |
|
|
|
ds.apply(lambda x: doc2numpy(x[Const.INPUTS(2)], word2id, char_dict, max(self.config.filter), |
|
|
|
self.config.max_sentences, is_train=name == 'train')[0], |
|
|
@@ -50,6 +85,9 @@ class CoreferencePipe(Pipe): |
|
|
|
self.config.max_sentences, is_train=name == 'train')[2], |
|
|
|
new_field_name=Const.INPUT_LEN) |
|
|
|
|
|
|
|
# clusters |
|
|
|
ds.rename_field("raw_clusters", Const.TARGET) |
|
|
|
|
|
|
|
|
|
|
|
ds.set_ignore_type(Const.TARGET) |
|
|
|
ds.set_padder(Const.TARGET, None) |
|
|
|