|
@@ -13,6 +13,7 @@ from fastNLP.loader.dataset_loader import ClassDatasetLoader |
|
|
from fastNLP.models.base_model import BaseModel |
|
|
from fastNLP.models.base_model import BaseModel |
|
|
from fastNLP.modules import aggregation |
|
|
from fastNLP.modules import aggregation |
|
|
from fastNLP.modules import encoder |
|
|
from fastNLP.modules import encoder |
|
|
|
|
|
from fastNLP.modules import decoder |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ClassificationModel(BaseModel): |
|
|
class ClassificationModel(BaseModel): |
|
@@ -20,20 +21,20 @@ class ClassificationModel(BaseModel): |
|
|
Simple text classification model based on CNN. |
|
|
Simple text classification model based on CNN. |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, class_num, vocab_size): |
|
|
|
|
|
|
|
|
def __init__(self, num_classes, vocab_size): |
|
|
super(ClassificationModel, self).__init__() |
|
|
super(ClassificationModel, self).__init__() |
|
|
|
|
|
|
|
|
self.embed = encoder.Embedding(nums=vocab_size, dims=300) |
|
|
|
|
|
self.conv = encoder.Conv( |
|
|
|
|
|
|
|
|
self.emb = encoder.Embedding(nums=vocab_size, dims=300) |
|
|
|
|
|
self.enc = encoder.Conv( |
|
|
in_channels=300, out_channels=100, kernel_size=3) |
|
|
in_channels=300, out_channels=100, kernel_size=3) |
|
|
self.pool = aggregation.MaxPool() |
|
|
|
|
|
self.output = encoder.Linear(input_size=100, output_size=class_num) |
|
|
|
|
|
|
|
|
self.agg = aggregation.MaxPool() |
|
|
|
|
|
self.dec = decoder.MLP(100, num_classes=num_classes) |
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
def forward(self, x): |
|
|
x = self.embed(x) # [N,L] -> [N,L,C] |
|
|
|
|
|
x = self.conv(x) # [N,L,C_in] -> [N,L,C_out] |
|
|
|
|
|
x = self.pool(x) # [N,L,C] -> [N,C] |
|
|
|
|
|
x = self.output(x) # [N,C] -> [N, N_class] |
|
|
|
|
|
|
|
|
x = self.emb(x) # [N,L] -> [N,L,C] |
|
|
|
|
|
x = self.enc(x) # [N,L,C_in] -> [N,L,C_out] |
|
|
|
|
|
x = self.agg(x) # [N,L,C] -> [N,C] |
|
|
|
|
|
x = self.dec(x) # [N,C] -> [N, N_class] |
|
|
return x |
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -55,7 +56,7 @@ model_args = { |
|
|
'num_classes': n_classes, |
|
|
'num_classes': n_classes, |
|
|
'vocab_size': vocab_size |
|
|
'vocab_size': vocab_size |
|
|
} |
|
|
} |
|
|
model = ClassificationModel(class_num=n_classes, vocab_size=vocab_size) |
|
|
|
|
|
|
|
|
model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size) |
|
|
|
|
|
|
|
|
# train model |
|
|
# train model |
|
|
train_args = { |
|
|
train_args = { |
|
@@ -75,4 +76,4 @@ trainer.cross_validate(model) |
|
|
# predict using model |
|
|
# predict using model |
|
|
data_infer = [x[0] for x in data] |
|
|
data_infer = [x[0] for x in data] |
|
|
infer = ClassificationInfer(data_dir) |
|
|
infer = ClassificationInfer(data_dir) |
|
|
labels_pred = infer.predict(model, data_infer) |
|
|
|
|
|
|
|
|
labels_pred = infer.predict(model, data_infer) |