diff --git a/README.md b/README.md
index 6b2254b5..a38771ee 100644
--- a/README.md
+++ b/README.md
@@ -30,6 +30,7 @@ A typical fastNLP routine is composed of four phases: loading dataset, pre-proce
 from fastNLP.models.base_model import BaseModel
 from fastNLP.modules import encoder
 from fastNLP.modules import aggregation
+from fastNLP.modules import decoder
 
 from fastNLP.loader.dataset_loader import ClassDatasetLoader
 from fastNLP.loader.preprocess import ClassPreprocess
@@ -42,20 +43,20 @@ class ClassificationModel(BaseModel):
     Simple text classification model based on CNN.
     """
 
-    def __init__(self, class_num, vocab_size):
+    def __init__(self, num_classes, vocab_size):
         super(ClassificationModel, self).__init__()
 
-        self.embed = encoder.Embedding(nums=vocab_size, dims=300)
-        self.conv = encoder.Conv(
+        self.emb = encoder.Embedding(nums=vocab_size, dims=300)
+        self.enc = encoder.Conv(
             in_channels=300, out_channels=100, kernel_size=3)
-        self.pool = aggregation.MaxPool()
-        self.output = encoder.Linear(input_size=100, output_size=class_num)
+        self.agg = aggregation.MaxPool()
+        self.dec = decoder.MLP(100, num_classes=num_classes)
 
     def forward(self, x):
-        x = self.embed(x)  # [N,L] -> [N,L,C]
-        x = self.conv(x)  # [N,L,C_in] -> [N,L,C_out]
-        x = self.pool(x)  # [N,L,C] -> [N,C]
-        x = self.output(x)  # [N,C] -> [N, N_class]
+        x = self.emb(x)  # [N,L] -> [N,L,C]
+        x = self.enc(x)  # [N,L,C_in] -> [N,L,C_out]
+        x = self.agg(x)  # [N,L,C] -> [N,C]
+        x = self.dec(x)  # [N,C] -> [N, N_class]
         return x
 
 
@@ -75,7 +76,7 @@ model_args = {
     'num_classes': n_classes,
     'vocab_size': vocab_size
 }
-model = ClassificationModel(class_num=n_classes, vocab_size=vocab_size)
+model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size)
 
 # train model
 train_args = {
diff --git a/test/readme_example.py b/test/readme_example.py
index 03cae2e6..17ac92c2 100644
--- a/test/readme_example.py
+++ b/test/readme_example.py
@@ -13,6 +13,7 @@ from fastNLP.loader.dataset_loader import ClassDatasetLoader
 from fastNLP.models.base_model import BaseModel
 from fastNLP.modules import aggregation
 from fastNLP.modules import encoder
+from fastNLP.modules import decoder
 
 
 class ClassificationModel(BaseModel):
@@ -20,20 +21,20 @@ class ClassificationModel(BaseModel):
     Simple text classification model based on CNN.
     """
 
-    def __init__(self, class_num, vocab_size):
+    def __init__(self, num_classes, vocab_size):
         super(ClassificationModel, self).__init__()
 
-        self.embed = encoder.Embedding(nums=vocab_size, dims=300)
-        self.conv = encoder.Conv(
+        self.emb = encoder.Embedding(nums=vocab_size, dims=300)
+        self.enc = encoder.Conv(
             in_channels=300, out_channels=100, kernel_size=3)
-        self.pool = aggregation.MaxPool()
-        self.output = encoder.Linear(input_size=100, output_size=class_num)
+        self.agg = aggregation.MaxPool()
+        self.dec = decoder.MLP(100, num_classes=num_classes)
 
     def forward(self, x):
-        x = self.embed(x)  # [N,L] -> [N,L,C]
-        x = self.conv(x)  # [N,L,C_in] -> [N,L,C_out]
-        x = self.pool(x)  # [N,L,C] -> [N,C]
-        x = self.output(x)  # [N,C] -> [N, N_class]
+        x = self.emb(x)  # [N,L] -> [N,L,C]
+        x = self.enc(x)  # [N,L,C_in] -> [N,L,C_out]
+        x = self.agg(x)  # [N,L,C] -> [N,C]
+        x = self.dec(x)  # [N,C] -> [N, N_class]
         return x
 
 
@@ -55,7 +56,7 @@ model_args = {
     'num_classes': n_classes,
     'vocab_size': vocab_size
 }
-model = ClassificationModel(class_num=n_classes, vocab_size=vocab_size)
+model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size)
 
 # train model
 train_args = {
@@ -75,4 +76,4 @@ trainer.cross_validate(model)
 # predict using model
 data_infer = [x[0] for x in data]
 infer = ClassificationInfer(data_dir)
-labels_pred = infer.predict(model, data_infer)
+labels_pred = infer.predict(model, data_infer)
\ No newline at end of file