@@ -978,7 +978,7 @@ class SaveModelCallback(Callback): | |||||
return save_pair, delete_pair | return save_pair, delete_pair | ||||
def _save_this_model(self, metric_value): | def _save_this_model(self, metric_value): | ||||
name = "epoch:{}_step:{}_{}:{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value) | |||||
name = "epoch-{}_step-{}_{}-{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value) | |||||
save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name)) | save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name)) | ||||
if save_pair: | if save_pair: | ||||
try: | try: | ||||
@@ -995,7 +995,7 @@ class SaveModelCallback(Callback): | |||||
def on_exception(self, exception): | def on_exception(self, exception): | ||||
if self.save_on_exception: | if self.save_on_exception: | ||||
name = "epoch:{}_step:{}_Exception:{}.pt".format(self.epoch, self.step, exception.__class__.__name__) | |||||
name = "epoch-{}_step-{}_Exception-{}.pt".format(self.epoch, self.step, exception.__class__.__name__) | |||||
_save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param) | _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param) | ||||
@@ -148,7 +148,7 @@ class TestDistTrainer(unittest.TestCase): | |||||
def run_dist(self, run_id): | def run_dist(self, run_id): | ||||
if torch.cuda.is_available(): | if torch.cuda.is_available(): | ||||
ngpu = min(2, torch.cuda.device_count()) | ngpu = min(2, torch.cuda.device_count()) | ||||
path = __file__ | |||||
path = os.path.abspath(__file__) | |||||
cmd = ['python', '-m', 'torch.distributed.launch', | cmd = ['python', '-m', 'torch.distributed.launch', | ||||
'--nproc_per_node', str(ngpu), path, '--test', str(run_id)] | '--nproc_per_node', str(ngpu), path, '--test', str(run_id)] | ||||
print(' '.join(cmd)) | print(' '.join(cmd)) | ||||
@@ -6,12 +6,24 @@ from fastNLP.models.cnn_text_classification import CNNText | |||||
class TestCNNText(unittest.TestCase): | class TestCNNText(unittest.TestCase): | ||||
def init_model(self, kernel_sizes, kernel_nums=(1,3,5)): | |||||
model = CNNText((VOCAB_SIZE, 30), | |||||
NUM_CLS, | |||||
kernel_nums=kernel_nums, | |||||
kernel_sizes=kernel_sizes) | |||||
return model | |||||
def test_case1(self): | def test_case1(self): | ||||
# 测试能否正常运行CNN | # 测试能否正常运行CNN | ||||
init_emb = (VOCAB_SIZE, 30) | |||||
model = CNNText(init_emb, | |||||
NUM_CLS, | |||||
kernel_nums=(1, 3, 5), | |||||
kernel_sizes=(1, 3, 5), | |||||
dropout=0.5) | |||||
model = self.init_model((1,3,5)) | |||||
RUNNER.run_model_with_task(TEXT_CLS, model) | |||||
def test_init_model(self): | |||||
self.assertRaises(Exception, self.init_model, (2,4)) | |||||
self.assertRaises(Exception, self.init_model, (2,)) | |||||
def test_output(self): | |||||
model = self.init_model((3,), (1,)) | |||||
global MAX_LEN | |||||
MAX_LEN = 2 | |||||
RUNNER.run_model_with_task(TEXT_CLS, model) | RUNNER.run_model_with_task(TEXT_CLS, model) |