Browse Source

跑通test_trainer.py,联调结束,准备发布

tags/v0.2.0^2
FengZiYjun 6 years ago
parent
commit
11c82ab2e7
2 changed files with 6 additions and 7 deletions
  1. +3
    -4
      fastNLP/models/base_model.py
  2. +3
    -3
      test/core/test_trainer.py

+ 3
- 4
fastNLP/models/base_model.py View File

@@ -20,11 +20,10 @@ class BaseModel(torch.nn.Module):
class NaiveClassifier(BaseModel):
def __init__(self, in_feature_dim, out_feature_dim):
super(NaiveClassifier, self).__init__()
self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim, out_feature_dim])
self.softmax = torch.nn.Softmax(dim=0)
self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim])

def forward(self, x):
return {"predict": self.softmax(self.mlp(x))}
return {"predict": torch.sigmoid(self.mlp(x))}

def predict(self, x):
return {"predict": self.softmax(self.mlp(x))}
return {"predict": torch.sigmoid(self.mlp(x)) > 0.5}

+ 3
- 3
test/core/test_trainer.py View File

@@ -13,11 +13,11 @@ from fastNLP.models.base_model import NaiveClassifier

class TrainerTestGround(unittest.TestCase):
def test_case(self):
mean = np.array([-5, -5])
mean = np.array([-3, -3])
cov = np.array([[1, 0], [0, 1]])
class_A = np.random.multivariate_normal(mean, cov, size=(1000,))

mean = np.array([5, 5])
mean = np.array([3, 3])
cov = np.array([[1, 0], [0, 1]])
class_B = np.random.multivariate_normal(mean, cov, size=(1000,))

@@ -39,7 +39,7 @@ class TrainerTestGround(unittest.TestCase):
print_every=10,
validate_every=-1,
dev_data=dev_set,
optimizer=SGD(0.001),
optimizer=SGD(0.1),
check_code_level=2
)
trainer.train()

Loading…
Cancel
Save