Browse Source

代码规范以及修改测试文件路径以匹配github文件路径

tags/v0.4.10
xxliu 5 years ago
parent
commit
5bbfb92a30
3 changed files with 4 additions and 4 deletions
  1. +2
    -2
      reproduction/coreference_resolution/train.py
  2. +1
    -1
      test/io/loader/test_coreference_loader.py
  3. +1
    -1
      test/io/pipe/test_coreference.py

+ 2
- 2
reproduction/coreference_resolution/train.py View File

@@ -45,7 +45,7 @@ if __name__ == "__main__":
print("数据集划分:\ntrain:", str(len(data_bundle.get_dataset("train"))), print("数据集划分:\ntrain:", str(len(data_bundle.get_dataset("train"))),
"\ndev:" + str(len(data_bundle.get_dataset("dev"))) + "\ntest:" + str(len(data_bundle.get_dataset('test')))) "\ndev:" + str(len(data_bundle.get_dataset("dev"))) + "\ntest:" + str(len(data_bundle.get_dataset('test'))))
# print(data_info) # print(data_info)
model = Model(data_bundle.vocabs['vocab'], config)
model = Model(data_bundle.get_vocab("vocab"), config)
print(model) print(model)


loss = SoftmaxLoss() loss = SoftmaxLoss()
@@ -60,7 +60,7 @@ if __name__ == "__main__":
loss=loss, metrics=metric, check_code_level=-1, sampler=None, loss=loss, metrics=metric, check_code_level=-1, sampler=None,
batch_size=1, device=torch.device("cuda:" + config.cuda), metric_key='f', n_epochs=config.epoch, batch_size=1, device=torch.device("cuda:" + config.cuda), metric_key='f', n_epochs=config.epoch,
optimizer=optim, optimizer=optim,
save_path='/remote-home/xxliu/pycharm/fastNLP/fastNLP/reproduction/coreference_resolution/save',
save_path= None,
callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)]) callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)])
print() print()




+ 1
- 1
test/io/loader/test_coreference_loader.py View File

@@ -4,7 +4,7 @@ import unittest
class TestCR(unittest.TestCase): class TestCR(unittest.TestCase):
def test_load(self): def test_load(self):


test_root = "../../data_for_tests/coreference/"
test_root = "test/data_for_tests/coreference/"
train_path = test_root+"coreference_train.json" train_path = test_root+"coreference_train.json"
dev_path = test_root+"coreference_dev.json" dev_path = test_root+"coreference_dev.json"
test_path = test_root+"coreference_test.json" test_path = test_root+"coreference_test.json"


+ 1
- 1
test/io/pipe/test_coreference.py View File

@@ -11,7 +11,7 @@ class TestCR(unittest.TestCase):
char_path = None char_path = None
config = Config() config = Config()


file_root_path = "../../data_for_tests/coreference/"
file_root_path = "test/data_for_tests/coreference/"
train_path = file_root_path + "coreference_train.json" train_path = file_root_path + "coreference_train.json"
dev_path = file_root_path + "coreference_dev.json" dev_path = file_root_path + "coreference_dev.json"
test_path = file_root_path + "coreference_test.json" test_path = file_root_path + "coreference_test.json"


Loading…
Cancel
Save