Browse Source

undocumented

tags/v0.4.10
xxliu 5 years ago
parent
commit
b015cc149c
7 changed files with 14 additions and 16 deletions
  1. +2
    -2
      fastNLP/io/loader/coreference.py
  2. +6
    -6
      fastNLP/io/pipe/coreference.py
  3. +2
    -1
      reproduction/coreference_resolution/train.py
  4. +1
    -1
      reproduction/coreference_resolution/valid.py
  5. +1
    -2
      test/data_for_tests/coreference/coreference_dev.json
  6. +1
    -2
      test/data_for_tests/coreference/coreference_test.json
  7. +1
    -2
      test/data_for_tests/coreference/coreference_train.json

+ 2
- 2
fastNLP/io/loader/coreference.py View File

@@ -26,8 +26,8 @@ class CRLoader(JsonLoader):
super().__init__(fields, dropna)
# self.fields = {"doc_key":Const.INPUTS(0),"speakers":Const.INPUTS(1),"clusters":Const.TARGET,"sentences":Const.INPUTS(2)}
# TODO check 1
- self.fields = {"doc_key": "raw_key", "speakers": "raw_speakers", "clusters": "raw_clusters",
- "sentences": "raw_words"}
+ self.fields = {"doc_key": Const.RAW_WORDS(0), "speakers": Const.RAW_WORDS(1), "clusters": Const.RAW_WORDS(2),
+ "sentences": Const.RAW_WORDS(3)}

def _load(self, path):
"""


+ 6
- 6
fastNLP/io/pipe/coreference.py View File

@@ -46,10 +46,10 @@ class CoreferencePipe(Pipe):
:return:
"""
genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])}
- vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name="raw_words")
+ vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name= Const.RAW_WORDS(3))
vocab.build_vocab()
word2id = vocab.word2idx
- data_bundle.set_vocab(vocab,"vocab")
+ data_bundle.set_vocab(vocab,Const.INPUT)
if self.config.char_path:
char_dict = get_char_dict(self.config.char_path)
else:
@@ -65,14 +65,14 @@ class CoreferencePipe(Pipe):

for name, ds in data_bundle.datasets.items():
# genre
- ds.apply(lambda x: genres[x["raw_key"][:2]], new_field_name=Const.INPUTS(0))
+ ds.apply(lambda x: genres[x[Const.RAW_WORDS(0)][:2]], new_field_name=Const.INPUTS(0))

# speaker_ids_np
- ds.apply(lambda x: speaker2numpy(x["raw_speakers"], self.config.max_sentences, is_train=name == 'train'),
+ ds.apply(lambda x: speaker2numpy(x[Const.RAW_WORDS(1)], self.config.max_sentences, is_train=name == 'train'),
new_field_name=Const.INPUTS(1))

# sentences
- ds.rename_field("raw_words",Const.INPUTS(2))
+ ds.rename_field(Const.RAW_WORDS(3),Const.INPUTS(2))

# doc_np
ds.apply(lambda x: doc2numpy(x[Const.INPUTS(2)], word2id, char_dict, max(self.config.filter),
@@ -88,7 +88,7 @@ class CoreferencePipe(Pipe):
new_field_name=Const.INPUT_LEN)

# clusters
- ds.rename_field("raw_clusters", Const.TARGET)
+ ds.rename_field(Const.RAW_WORDS(2), Const.TARGET)


ds.set_ignore_type(Const.TARGET)


+ 2
- 1
reproduction/coreference_resolution/train.py View File

@@ -8,6 +8,7 @@ from fastNLP.core.callback import Callback, GradientClipCallback
from fastNLP.core.trainer import Trainer

from fastNLP.io.pipe.coreference import CoreferencePipe
+ from fastNLP.core.const import Const

from reproduction.coreference_resolution.model.config import Config
from reproduction.coreference_resolution.model.model_re import Model
@@ -45,7 +46,7 @@ if __name__ == "__main__":
print("数据集划分:\ntrain:", str(len(data_bundle.get_dataset("train"))),
"\ndev:" + str(len(data_bundle.get_dataset("dev"))) + "\ntest:" + str(len(data_bundle.get_dataset('test'))))
# print(data_info)
- model = Model(data_bundle.get_vocab("vocab"), config)
+ model = Model(data_bundle.get_vocab(Const.INPUT), config)
print(model)

loss = SoftmaxLoss()


+ 1
- 1
reproduction/coreference_resolution/valid.py View File

@@ -17,7 +17,7 @@ if __name__=='__main__':
{'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path})
metirc = CRMetric()
model = torch.load(args.path)
- tester = Tester(bundle.datasets['test'],model,metirc,batch_size=1,device="cuda:0")
+ tester = Tester(bundle.get_dataset("test"),model,metirc,batch_size=1,device="cuda:0")
tester.test()
print('test over')



+ 1
- 2
test/data_for_tests/coreference/coreference_dev.json
File diff suppressed because it is too large
View File


+ 1
- 2
test/data_for_tests/coreference/coreference_test.json
File diff suppressed because it is too large
View File


+ 1
- 2
test/data_for_tests/coreference/coreference_train.json
File diff suppressed because it is too large
View File


Loading…
Cancel
Save