@@ -26,8 +26,8 @@ class CRLoader(JsonLoader):
         super().__init__(fields, dropna)
         # self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1),"clusters":Const.TARGET,"sentences":Const.INPUTS(2)}
         # TODO check 1
-        self.fields = {"doc_key": "raw_key", "speakers": "raw_speakers", "clusters": "raw_clusters",
-                       "sentences": "raw_words"}
+        self.fields = {"doc_key": Const.RAW_WORDS(0), "speakers": Const.RAW_WORDS(1), "clusters": Const.RAW_WORDS(2),
+                       "sentences": Const.RAW_WORDS(3)}

     def _load(self, path):
         """
@@ -46,10 +46,10 @@ class CoreferencePipe(Pipe):
         :return:
         """
         genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])}
-        vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name="raw_words")
+        vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name=Const.RAW_WORDS(3))
         vocab.build_vocab()
         word2id = vocab.word2idx
-        data_bundle.set_vocab(vocab,"vocab")
+        data_bundle.set_vocab(vocab, Const.INPUT)
         if self.config.char_path:
             char_dict = get_char_dict(self.config.char_path)
         else:
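
The vocabulary is built over every split and registered under Const.INPUT instead of the literal "vocab". A minimal sketch of the consumer side of that contract (assumes a bundle already processed by this pipe; the training-script hunk below does exactly this):

    # Sketch: fetch the vocabulary under the same constant it was stored with,
    # so no string key has to be kept in sync across files.
    vocab = data_bundle.get_vocab(Const.INPUT)  # same key as set_vocab above
    word2id = vocab.word2idx                    # dict mapping token -> id
    print(len(word2id), "tokens in the shared vocabulary")
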
@@ -65,14 +65,14 @@ class CoreferencePipe(Pipe):
         for name, ds in data_bundle.datasets.items():
             # genre
-            ds.apply(lambda x: genres[x["raw_key"][:2]], new_field_name=Const.INPUTS(0))
+            ds.apply(lambda x: genres[x[Const.RAW_WORDS(0)][:2]], new_field_name=Const.INPUTS(0))
             # speaker_ids_np
-            ds.apply(lambda x: speaker2numpy(x["raw_speakers"], self.config.max_sentences, is_train=name == 'train'),
+            ds.apply(lambda x: speaker2numpy(x[Const.RAW_WORDS(1)], self.config.max_sentences, is_train=name == 'train'),
                      new_field_name=Const.INPUTS(1))
             # sentences
-            ds.rename_field("raw_words",Const.INPUTS(2))
+            ds.rename_field(Const.RAW_WORDS(3), Const.INPUTS(2))
             # doc_np
             ds.apply(lambda x: doc2numpy(x[Const.INPUTS(2)], word2id, char_dict, max(self.config.filter),
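
On the genre lookup above: OntoNotes document keys begin with a two-letter genre code, so slicing the first two characters of the doc-key field indexes straight into genres. A small illustration with a hypothetical key:

    # Sketch of the genre lookup used in the first apply() of this hunk.
    genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])}
    doc_key = "bc/cctv/00/cctv_0000"  # hypothetical OntoNotes-style key
    genre_id = genres[doc_key[:2]]    # "bc" (broadcast conversation) -> 0
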
@@ -88,7 +88,7 @@ class CoreferencePipe(Pipe):
                      new_field_name=Const.INPUT_LEN)
             # clusters
-            ds.rename_field("raw_clusters", Const.TARGET)
+            ds.rename_field(Const.RAW_WORDS(2), Const.TARGET)
             ds.set_ignore_type(Const.TARGET)
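
The clusters field stays a ragged nested list (documents differ in cluster count, clusters differ in mention count), which is why it is renamed to the standard target name and then excluded from type inference with set_ignore_type instead of being padded into a tensor. A hypothetical example of the shape involved:

    # Sketch: one document's gold clusters, as left in the target field.
    # Each mention is a [start, end] token-index span; the nesting is ragged.
    clusters = [
        [[0, 1], [5, 6]],  # cluster 1: two mentions
        [[3, 3]],          # cluster 2: one mention
    ]
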
@@ -8,6 +8,7 @@ from fastNLP.core.callback import Callback, GradientClipCallback
 from fastNLP.core.trainer import Trainer
 from fastNLP.io.pipe.coreference import CoreferencePipe
+from fastNLP.core.const import Const

 from reproduction.coreference_resolution.model.config import Config
 from reproduction.coreference_resolution.model.model_re import Model
@@ -45,7 +46,7 @@ if __name__ == "__main__": | |||||
print("数据集划分:\ntrain:", str(len(data_bundle.get_dataset("train"))), | print("数据集划分:\ntrain:", str(len(data_bundle.get_dataset("train"))), | ||||
"\ndev:" + str(len(data_bundle.get_dataset("dev"))) + "\ntest:" + str(len(data_bundle.get_dataset('test')))) | "\ndev:" + str(len(data_bundle.get_dataset("dev"))) + "\ntest:" + str(len(data_bundle.get_dataset('test')))) | ||||
# print(data_info) | # print(data_info) | ||||
model = Model(data_bundle.get_vocab("vocab"), config) | |||||
model = Model(data_bundle.get_vocab(Const.INPUT), config) | |||||
print(model) | print(model) | ||||
loss = SoftmaxLoss() | loss = SoftmaxLoss() | ||||
@@ -17,7 +17,7 @@ if __name__ == '__main__':
         {'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})
     metric = CRMetric()
     model = torch.load(args.path)
-    tester = Tester(bundle.datasets['test'], model, metric, batch_size=1, device="cuda:0")
+    tester = Tester(bundle.get_dataset("test"), model, metric, batch_size=1, device="cuda:0")
     tester.test()
     print('test over')
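
Both forms in the last change reach the same object; get_dataset simply goes through the bundle's public accessor rather than its internal dict. A minimal sketch, assuming a loaded bundle:

    # Sketch: the accessor and the internal dict return the same dataset;
    # get_dataset is the supported interface for retrieving a split.
    test_ds = bundle.get_dataset("test")
    assert test_ds is bundle.datasets["test"]
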