Browse Source

[bugfix] fix test cases

tags/v0.4.10
yunfan 5 years ago
parent
commit
c73f327a3e
3 changed files with 21 additions and 9 deletions
  1. +2
    -2
      fastNLP/core/callback.py
  2. +1
    -1
      test/core/test_dist_trainer.py
  3. +18
    -6
      test/models/test_cnn_text_classification.py

+ 2
- 2
fastNLP/core/callback.py View File

@@ -978,7 +978,7 @@ class SaveModelCallback(Callback):
return save_pair, delete_pair

def _save_this_model(self, metric_value):
name = "epoch:{}_step:{}_{}:{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value)
name = "epoch-{}_step-{}_{}-{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value)
save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name))
if save_pair:
try:
@@ -995,7 +995,7 @@ class SaveModelCallback(Callback):

def on_exception(self, exception):
if self.save_on_exception:
name = "epoch:{}_step:{}_Exception:{}.pt".format(self.epoch, self.step, exception.__class__.__name__)
name = "epoch-{}_step-{}_Exception-{}.pt".format(self.epoch, self.step, exception.__class__.__name__)
_save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)




+ 1
- 1
test/core/test_dist_trainer.py View File

@@ -148,7 +148,7 @@ class TestDistTrainer(unittest.TestCase):
def run_dist(self, run_id):
if torch.cuda.is_available():
ngpu = min(2, torch.cuda.device_count())
path = __file__
path = os.path.abspath(__file__)
cmd = ['python', '-m', 'torch.distributed.launch',
'--nproc_per_node', str(ngpu), path, '--test', str(run_id)]
print(' '.join(cmd))


+ 18
- 6
test/models/test_cnn_text_classification.py View File

@@ -6,12 +6,24 @@ from fastNLP.models.cnn_text_classification import CNNText


class TestCNNText(unittest.TestCase):
def init_model(self, kernel_sizes, kernel_nums=(1,3,5)):
model = CNNText((VOCAB_SIZE, 30),
NUM_CLS,
kernel_nums=kernel_nums,
kernel_sizes=kernel_sizes)
return model

def test_case1(self):
# 测试能否正常运行CNN
init_emb = (VOCAB_SIZE, 30)
model = CNNText(init_emb,
NUM_CLS,
kernel_nums=(1, 3, 5),
kernel_sizes=(1, 3, 5),
dropout=0.5)
model = self.init_model((1,3,5))
RUNNER.run_model_with_task(TEXT_CLS, model)

def test_init_model(self):
self.assertRaises(Exception, self.init_model, (2,4))
self.assertRaises(Exception, self.init_model, (2,))

def test_output(self):
model = self.init_model((3,), (1,))
global MAX_LEN
MAX_LEN = 2
RUNNER.run_model_with_task(TEXT_CLS, model)

Loading…
Cancel
Save