From c73f327a3efd8419c78c22cc7067cc84e1e8a73d Mon Sep 17 00:00:00 2001 From: yunfan Date: Wed, 25 Sep 2019 19:49:30 +0800 Subject: [PATCH] [bugfix] fix test cases --- fastNLP/core/callback.py | 4 ++-- test/core/test_dist_trainer.py | 2 +- test/models/test_cnn_text_classification.py | 24 +++++++++++++++------ 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 7c35b31d..6231fc18 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -978,7 +978,7 @@ class SaveModelCallback(Callback): return save_pair, delete_pair def _save_this_model(self, metric_value): - name = "epoch:{}_step:{}_{}:{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value) + name = "epoch-{}_step-{}_{}-{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value) save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name)) if save_pair: try: @@ -995,7 +995,7 @@ class SaveModelCallback(Callback): def on_exception(self, exception): if self.save_on_exception: - name = "epoch:{}_step:{}_Exception:{}.pt".format(self.epoch, self.step, exception.__class__.__name__) + name = "epoch-{}_step-{}_Exception-{}.pt".format(self.epoch, self.step, exception.__class__.__name__) _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param) diff --git a/test/core/test_dist_trainer.py b/test/core/test_dist_trainer.py index 3b53fe50..d2a11a76 100644 --- a/test/core/test_dist_trainer.py +++ b/test/core/test_dist_trainer.py @@ -148,7 +148,7 @@ class TestDistTrainer(unittest.TestCase): def run_dist(self, run_id): if torch.cuda.is_available(): ngpu = min(2, torch.cuda.device_count()) - path = __file__ + path = os.path.abspath(__file__) cmd = ['python', '-m', 'torch.distributed.launch', '--nproc_per_node', str(ngpu), path, '--test', str(run_id)] print(' '.join(cmd)) diff --git a/test/models/test_cnn_text_classification.py b/test/models/test_cnn_text_classification.py index 2ea48220..29154bd6 100644 --- a/test/models/test_cnn_text_classification.py +++ b/test/models/test_cnn_text_classification.py @@ -6,12 +6,24 @@ from fastNLP.models.cnn_text_classification import CNNText class TestCNNText(unittest.TestCase): + def init_model(self, kernel_sizes, kernel_nums=(1,3,5)): + model = CNNText((VOCAB_SIZE, 30), + NUM_CLS, + kernel_nums=kernel_nums, + kernel_sizes=kernel_sizes) + return model + def test_case1(self): # 测试能否正常运行CNN - init_emb = (VOCAB_SIZE, 30) - model = CNNText(init_emb, - NUM_CLS, - kernel_nums=(1, 3, 5), - kernel_sizes=(1, 3, 5), - dropout=0.5) + model = self.init_model((1,3,5)) + RUNNER.run_model_with_task(TEXT_CLS, model) + + def test_init_model(self): + self.assertRaises(Exception, self.init_model, (2,4)) + self.assertRaises(Exception, self.init_model, (2,)) + + def test_output(self): + model = self.init_model((3,), (1,)) + global MAX_LEN + MAX_LEN = 2 RUNNER.run_model_with_task(TEXT_CLS, model)