|
|
@@ -134,7 +134,7 @@ class TestCallback(unittest.TestCase): |
|
|
|
|
|
|
|
def test_fitlog_callback(self): |
|
|
|
import fitlog |
|
|
|
fitlog.set_log_dir(self.tempdir) |
|
|
|
fitlog.set_log_dir(self.tempdir, new_log=True) |
|
|
|
data_set, model = prepare_env() |
|
|
|
from fastNLP import Tester |
|
|
|
tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y")) |
|
|
@@ -164,7 +164,7 @@ class TestCallback(unittest.TestCase): |
|
|
|
tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y")) |
|
|
|
import fitlog |
|
|
|
|
|
|
|
fitlog.set_log_dir(self.tempdir) |
|
|
|
fitlog.set_log_dir(self.tempdir, new_log=True) |
|
|
|
tempfile_path = os.path.join(self.tempdir, 'chkt.pt') |
|
|
|
callbacks = [CheckPointCallback(tempfile_path)] |
|
|
|
|
|
|
|