|
|
@@ -46,10 +46,10 @@ class CoreferencePipe(Pipe): |
|
|
|
:return: |
|
|
|
""" |
|
|
|
genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])} |
|
|
|
vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name="raw_words") |
|
|
|
vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name= Const.RAW_WORDS(3)) |
|
|
|
vocab.build_vocab() |
|
|
|
word2id = vocab.word2idx |
|
|
|
data_bundle.set_vocab(vocab,"vocab") |
|
|
|
data_bundle.set_vocab(vocab,Const.INPUT) |
|
|
|
if self.config.char_path: |
|
|
|
char_dict = get_char_dict(self.config.char_path) |
|
|
|
else: |
|
|
@@ -65,14 +65,14 @@ class CoreferencePipe(Pipe): |
|
|
|
|
|
|
|
for name, ds in data_bundle.datasets.items(): |
|
|
|
# genre |
|
|
|
ds.apply(lambda x: genres[x["raw_key"][:2]], new_field_name=Const.INPUTS(0)) |
|
|
|
ds.apply(lambda x: genres[x[Const.RAW_WORDS(0)][:2]], new_field_name=Const.INPUTS(0)) |
|
|
|
|
|
|
|
# speaker_ids_np |
|
|
|
ds.apply(lambda x: speaker2numpy(x["raw_speakers"], self.config.max_sentences, is_train=name == 'train'), |
|
|
|
ds.apply(lambda x: speaker2numpy(x[Const.RAW_WORDS(1)], self.config.max_sentences, is_train=name == 'train'), |
|
|
|
new_field_name=Const.INPUTS(1)) |
|
|
|
|
|
|
|
# sentences |
|
|
|
ds.rename_field("raw_words",Const.INPUTS(2)) |
|
|
|
ds.rename_field(Const.RAW_WORDS(3),Const.INPUTS(2)) |
|
|
|
|
|
|
|
# doc_np |
|
|
|
ds.apply(lambda x: doc2numpy(x[Const.INPUTS(2)], word2id, char_dict, max(self.config.filter), |
|
|
@@ -88,7 +88,7 @@ class CoreferencePipe(Pipe): |
|
|
|
new_field_name=Const.INPUT_LEN) |
|
|
|
|
|
|
|
# clusters |
|
|
|
ds.rename_field("raw_clusters", Const.TARGET) |
|
|
|
ds.rename_field(Const.RAW_WORDS(2), Const.TARGET) |
|
|
|
|
|
|
|
|
|
|
|
ds.set_ignore_type(Const.TARGET) |
|
|
|