|
|
@@ -13,11 +13,11 @@ from fastNLP.models.base_model import NaiveClassifier |
|
|
|
|
|
|
|
class TrainerTestGround(unittest.TestCase): |
|
|
|
def test_case(self): |
|
|
|
mean = np.array([-5, -5]) |
|
|
|
mean = np.array([-3, -3]) |
|
|
|
cov = np.array([[1, 0], [0, 1]]) |
|
|
|
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) |
|
|
|
|
|
|
|
mean = np.array([5, 5]) |
|
|
|
mean = np.array([3, 3]) |
|
|
|
cov = np.array([[1, 0], [0, 1]]) |
|
|
|
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) |
|
|
|
|
|
|
@@ -39,7 +39,7 @@ class TrainerTestGround(unittest.TestCase): |
|
|
|
print_every=10, |
|
|
|
validate_every=-1, |
|
|
|
dev_data=dev_set, |
|
|
|
optimizer=SGD(0.001), |
|
|
|
optimizer=SGD(0.1), |
|
|
|
check_code_level=2 |
|
|
|
) |
|
|
|
trainer.train() |