From 5bbfb92a300d8d9aeba7f45ae4e2bf8dad19fcb4 Mon Sep 17 00:00:00 2001 From: xxliu Date: Fri, 6 Sep 2019 13:08:57 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BB=A3=E7=A0=81=E8=A7=84=E8=8C=83=E4=BB=A5?= =?UTF-8?q?=E5=8F=8A=E4=BF=AE=E6=94=B9=E6=B5=8B=E8=AF=95=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E8=B7=AF=E5=BE=84=E4=BB=A5=E5=8C=B9=E9=85=8Dgithub=E6=96=87?= =?UTF-8?q?=E4=BB=B6=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- reproduction/coreference_resolution/train.py | 4 ++-- test/io/loader/test_coreference_loader.py | 2 +- test/io/pipe/test_coreference.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/reproduction/coreference_resolution/train.py b/reproduction/coreference_resolution/train.py index c91f7109..cd4b65a5 100644 --- a/reproduction/coreference_resolution/train.py +++ b/reproduction/coreference_resolution/train.py @@ -45,7 +45,7 @@ if __name__ == "__main__": print("数据集划分:\ntrain:", str(len(data_bundle.get_dataset("train"))), "\ndev:" + str(len(data_bundle.get_dataset("dev"))) + "\ntest:" + str(len(data_bundle.get_dataset('test')))) # print(data_info) - model = Model(data_bundle.vocabs['vocab'], config) + model = Model(data_bundle.get_vocab("vocab"), config) print(model) loss = SoftmaxLoss() @@ -60,7 +60,7 @@ if __name__ == "__main__": loss=loss, metrics=metric, check_code_level=-1, sampler=None, batch_size=1, device=torch.device("cuda:" + config.cuda), metric_key='f', n_epochs=config.epoch, optimizer=optim, - save_path='/remote-home/xxliu/pycharm/fastNLP/fastNLP/reproduction/coreference_resolution/save', + save_path= None, callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)]) print() diff --git a/test/io/loader/test_coreference_loader.py b/test/io/loader/test_coreference_loader.py index 48551f3e..d827e947 100644 --- a/test/io/loader/test_coreference_loader.py +++ b/test/io/loader/test_coreference_loader.py @@ -4,7 +4,7 @@ import unittest class TestCR(unittest.TestCase): def test_load(self): - test_root = "../../data_for_tests/coreference/" + test_root = "test/data_for_tests/coreference/" train_path = test_root+"coreference_train.json" dev_path = test_root+"coreference_dev.json" test_path = test_root+"coreference_test.json" diff --git a/test/io/pipe/test_coreference.py b/test/io/pipe/test_coreference.py index 1c53f2b0..517be993 100644 --- a/test/io/pipe/test_coreference.py +++ b/test/io/pipe/test_coreference.py @@ -11,7 +11,7 @@ class TestCR(unittest.TestCase): char_path = None config = Config() - file_root_path = "../../data_for_tests/coreference/" + file_root_path = "test/data_for_tests/coreference/" train_path = file_root_path + "coreference_train.json" dev_path = file_root_path + "coreference_dev.json" test_path = file_root_path + "coreference_test.json"