|
@@ -22,6 +22,9 @@ class BaseModel(torch.nn.Module): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NaiveClassifier(BaseModel): |
|
|
class NaiveClassifier(BaseModel): |
|
|
|
|
|
""" |
|
|
|
|
|
一个简单的分类器例子,可用于各种测试 |
|
|
|
|
|
""" |
|
|
def __init__(self, in_feature_dim, out_feature_dim): |
|
|
def __init__(self, in_feature_dim, out_feature_dim): |
|
|
super(NaiveClassifier, self).__init__() |
|
|
super(NaiveClassifier, self).__init__() |
|
|
self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) |
|
|
self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) |
|
|