Browse Source

Merge pull request #46 from fastnlp/modify-readme-example

modify readme example
tags/v0.1.0
Coet GitHub 6 years ago
parent
commit
96391d6ab3
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 21 deletions
  1. +11
    -10
      README.md
  2. +12
    -11
      test/readme_example.py

+ 11
- 10
README.md View File

@@ -30,6 +30,7 @@ A typical fastNLP routine is composed of four phases: loading dataset, pre-proce
from fastNLP.models.base_model import BaseModel
from fastNLP.modules import encoder
from fastNLP.modules import aggregation
from fastNLP.modules import decoder

from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.preprocess import ClassPreprocess
@@ -42,20 +43,20 @@ class ClassificationModel(BaseModel):
Simple text classification model based on CNN.
"""

def __init__(self, class_num, vocab_size):
def __init__(self, num_classes, vocab_size):
super(ClassificationModel, self).__init__()

self.embed = encoder.Embedding(nums=vocab_size, dims=300)
self.conv = encoder.Conv(
self.emb = encoder.Embedding(nums=vocab_size, dims=300)
self.enc = encoder.Conv(
in_channels=300, out_channels=100, kernel_size=3)
self.pool = aggregation.MaxPool()
self.output = encoder.Linear(input_size=100, output_size=class_num)
self.agg = aggregation.MaxPool()
self.dec = decoder.MLP(100, num_classes=num_classes)

def forward(self, x):
x = self.embed(x) # [N,L] -> [N,L,C]
x = self.conv(x) # [N,L,C_in] -> [N,L,C_out]
x = self.pool(x) # [N,L,C] -> [N,C]
x = self.output(x) # [N,C] -> [N, N_class]
x = self.emb(x) # [N,L] -> [N,L,C]
x = self.enc(x) # [N,L,C_in] -> [N,L,C_out]
x = self.agg(x) # [N,L,C] -> [N,C]
x = self.dec(x) # [N,C] -> [N, N_class]
return x


@@ -75,7 +76,7 @@ model_args = {
'num_classes': n_classes,
'vocab_size': vocab_size
}
model = ClassificationModel(class_num=n_classes, vocab_size=vocab_size)
model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size)

# train model
train_args = {


+ 12
- 11
test/readme_example.py View File

@@ -13,6 +13,7 @@ from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.models.base_model import BaseModel
from fastNLP.modules import aggregation
from fastNLP.modules import encoder
from fastNLP.modules import decoder


class ClassificationModel(BaseModel):
@@ -20,20 +21,20 @@ class ClassificationModel(BaseModel):
Simple text classification model based on CNN.
"""

def __init__(self, class_num, vocab_size):
def __init__(self, num_classes, vocab_size):
super(ClassificationModel, self).__init__()

self.embed = encoder.Embedding(nums=vocab_size, dims=300)
self.conv = encoder.Conv(
self.emb = encoder.Embedding(nums=vocab_size, dims=300)
self.enc = encoder.Conv(
in_channels=300, out_channels=100, kernel_size=3)
self.pool = aggregation.MaxPool()
self.output = encoder.Linear(input_size=100, output_size=class_num)
self.agg = aggregation.MaxPool()
self.dec = decoder.MLP(100, num_classes=num_classes)

def forward(self, x):
x = self.embed(x) # [N,L] -> [N,L,C]
x = self.conv(x) # [N,L,C_in] -> [N,L,C_out]
x = self.pool(x) # [N,L,C] -> [N,C]
x = self.output(x) # [N,C] -> [N, N_class]
x = self.emb(x) # [N,L] -> [N,L,C]
x = self.enc(x) # [N,L,C_in] -> [N,L,C_out]
x = self.agg(x) # [N,L,C] -> [N,C]
x = self.dec(x) # [N,C] -> [N, N_class]
return x


@@ -55,7 +56,7 @@ model_args = {
'num_classes': n_classes,
'vocab_size': vocab_size
}
model = ClassificationModel(class_num=n_classes, vocab_size=vocab_size)
model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size)

# train model
train_args = {
@@ -75,4 +76,4 @@ trainer.cross_validate(model)
# predict using model
data_infer = [x[0] for x in data]
infer = ClassificationInfer(data_dir)
labels_pred = infer.predict(model, data_infer)
labels_pred = infer.predict(model, data_infer)

Loading…
Cancel
Save