|
@@ -12,6 +12,7 @@ from fastNLP import AccuracyMetric |
|
|
from fastNLP import SGD |
|
|
from fastNLP import SGD |
|
|
from fastNLP import Trainer |
|
|
from fastNLP import Trainer |
|
|
from fastNLP.models.base_model import NaiveClassifier |
|
|
from fastNLP.models.base_model import NaiveClassifier |
|
|
|
|
|
from fastNLP.core.callback import EarlyStopError |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_env(): |
|
|
def prepare_env(): |
|
@@ -65,7 +66,8 @@ class TestCallback(unittest.TestCase): |
|
|
dev_data=data_set, |
|
|
dev_data=data_set, |
|
|
metrics=AccuracyMetric(pred="predict", target="y"), |
|
|
metrics=AccuracyMetric(pred="predict", target="y"), |
|
|
callbacks=[EarlyStopCallback(5)]) |
|
|
callbacks=[EarlyStopCallback(5)]) |
|
|
trainer.train() |
|
|
|
|
|
|
|
|
with self.assertRaises(EarlyStopError): |
|
|
|
|
|
trainer.train() |
|
|
|
|
|
|
|
|
def test_lr_scheduler(self): |
|
|
def test_lr_scheduler(self): |
|
|
data_set, model = prepare_env() |
|
|
data_set, model = prepare_env() |
|
|