diff --git a/README.md b/README.md
index 4068febc..efa1b15b 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,126 @@
-# FastNLP
+# fastNLP
+
 [![Build Status](https://travis-ci.org/fastnlp/fastNLP.svg?branch=master)](https://travis-ci.org/fastnlp/fastNLP)
 [![codecov](https://codecov.io/gh/fastnlp/fastNLP/branch/master/graph/badge.svg)](https://codecov.io/gh/fastnlp/fastNLP)
+
+fastNLP is a modular Natural Language Processing system based on PyTorch, built for fast development of NLP tools. It decomposes deep-learning NLP models into modules that fall into four categories: encoder, interaction, aggregation and decoder, each with several implementations to choose from. Encoder modules encode the input into an abstract representation; interaction modules let the pieces of information in that representation interact with one another; aggregation modules aggregate and reduce information; decoder modules decode the representation into the output. Most current NLP models can be assembled from these modules, which greatly simplifies model development. The architecture of fastNLP is shown in the figure below:
+
+![](https://github.com/fastnlp/fastNLP/raw/master/fastnlp-architecture.pdf)
+
+
+## Requirements
+
+- numpy>=1.14.2
+- torch==0.4.0
+- torchvision>=0.1.8
+
+
+## Resources
+
+- [Documentation](https://github.com/fastnlp/fastNLP)
+- [Source Code](https://github.com/fastnlp/fastNLP)
+
+
+## Example
+
+### Basic Usage
+
+A typical fastNLP routine is composed of four phases: loading the dataset, pre-processing the data, constructing the model and training the model.
+```python
+from fastNLP.models.base_model import BaseModel
+from fastNLP.modules import encoder
+from fastNLP.modules import aggregation
+
+from fastNLP.loader.dataset_loader import ClassDatasetLoader
+from fastNLP.loader.preprocess import ClassPreprocess
+from fastNLP.core.trainer import ClassificationTrainer
+from fastNLP.core.inference import ClassificationInfer
+
+
+class ClassificationModel(BaseModel):
+    """
+    Simple text classification model based on CNN.
+    """
+
+    def __init__(self, class_num, vocab_size):
+        super(ClassificationModel, self).__init__()
+
+        self.embed = encoder.Embedding(nums=vocab_size, dims=300)
+        self.conv = encoder.Conv(
+            in_channels=300, out_channels=100, kernel_size=3)
+        self.pool = aggregation.MaxPool()
+        self.output = encoder.Linear(input_size=100, output_size=class_num)
+
+    def forward(self, x):
+        x = self.embed(x)  # [N,L] -> [N,L,C]
+        x = self.conv(x)  # [N,L,C_in] -> [N,L,C_out]
+        x = self.pool(x)  # [N,L,C] -> [N,C]
+        x = self.output(x)  # [N,C] -> [N,N_class]
+        return x
+
+
+data_dir = 'data'  # directory to save data and model
+train_path = 'test/data_for_tests/text_classify.txt'  # training set file
+
+# load dataset
+ds_loader = ClassDatasetLoader("train", train_path)
+data = ds_loader.load()
+
+# pre-process dataset
+pre = ClassPreprocess(data_dir)
+vocab_size, n_classes = pre.process(data, "data_train.pkl")
+
+# construct model
+model_args = {
+    'class_num': n_classes,
+    'vocab_size': vocab_size
+}
+model = ClassificationModel(**model_args)
+
+# train model
+train_args = {
+    "epochs": 20,
+    "batch_size": 50,
+    "pickle_path": data_dir,
+    "validate": False,
+    "save_best_dev": False,
+    "model_saved_path": None,
+    "use_cuda": True,
+    "learn_rate": 1e-3,
+    "momentum": 0.9}
+trainer = ClassificationTrainer(train_args)
+trainer.train(model)
+
+# predict using model
+seqs = [x[0] for x in data]
+infer = ClassificationInfer(data_dir)
+labels_pred = infer.predict(model, seqs)
+```
+
+
+## Installation
+
+### Cloning From GitHub
+
+If you just want to use fastNLP, clone the repository:
+```shell
+git clone https://github.com/fastnlp/fastNLP
+cd fastNLP
+```
+
+### PyTorch Installation
+
+Visit the [PyTorch official website](https://pytorch.org/) for installation instructions based on your system. In general, you can use:
+```shell
+# using conda
+conda install pytorch torchvision -c pytorch
+# or using pip
+pip3 install torch torchvision
+```
+
+
+## Project Structure
+
 ```
 FastNLP
 ├── docs
@@ -90,5 +210,4 @@ FastNLP
 ├── test_seq_labeling.py
 ├── test_tester.py
 └── test_trainer.py
-
 ```
diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py
index 172cf584..758f5efb 100644
--- a/fastNLP/core/predictor.py
+++ b/fastNLP/core/predictor.py
@@ -196,6 +196,6 @@ class ClassificationInfer(Predictor):
         """
         results = []
         for batch_out in batch_outputs:
-            idx = np.argmax(batch_out.detach().numpy())
-            results.append(self.index2label[idx])
+            idx = np.argmax(batch_out.detach().numpy(), axis=-1)
+            results.extend([self.index2label[i] for i in idx])
         return results
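The `predictor.py` change above fixes batched classification inference: `np.argmax` without an `axis` argument flattens its input, so the old code produced a single index for an entire `[batch, n_classes]` output and appended only one label per batch. A minimal sketch of the difference (the scores and the `index2label` mapping below are made up for illustration):

```python
import numpy as np

# Toy stand-ins for one batch of class scores and the label vocabulary.
batch_out = np.array([[0.1, 0.9],
                      [0.8, 0.2],
                      [0.3, 0.7]])   # [batch=3, n_classes=2]
index2label = {0: "neg", 1: "pos"}

# Old behaviour: argmax over the flattened array -> one index per batch.
print(np.argmax(batch_out))           # 1 (a single flat index)

# New behaviour: argmax along the class axis -> one index per example.
idx = np.argmax(batch_out, axis=-1)   # array([1, 0, 1])
print([index2label[i] for i in idx])  # ['pos', 'neg', 'pos']
```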
+ """ + + def __init__(self, class_num, vocab_size): + super(ClassificationModel, self).__init__() + + self.embed = encoder.Embedding(nums=vocab_size, dims=300) + self.conv = encoder.Conv( + in_channels=300, out_channels=100, kernel_size=3) + self.pool = aggregation.MaxPool() + self.output = encoder.Linear(input_size=100, output_size=class_num) + + def forward(self, x): + x = self.embed(x) # [N,L] -> [N,L,C] + x = self.conv(x) # [N,L,C_in] -> [N,L,C_out] + x = self.pool(x) # [N,L,C] -> [N,C] + x = self.output(x) # [N,C] -> [N, N_class] + return x + + +data_dir = 'data' # directory to save data and model +train_path = 'test/data_for_tests/text_classify.txt' # training set file + +# load dataset +ds_loader = ClassDatasetLoader("train", train_path) +data = ds_loader.load() + +# pre-process dataset +pre = ClassPreprocess(data_dir) +vocab_size, n_classes = pre.process(data, "data_train.pkl") + +# construct model +model_args = { + 'num_classes': n_classes, + 'vocab_size': vocab_size +} +model = ClassificationModel(class_num=n_classes, vocab_size=vocab_size) + +# train model +train_args = { + "epochs": 20, + "batch_size": 50, + "pickle_path": data_dir, + "validate": False, + "save_best_dev": False, + "model_saved_path": None, + "use_cuda": True, + "learn_rate": 1e-3, + "momentum": 0.9} +trainer = ClassificationTrainer(train_args) +trainer.train(model) + +# predict using model +seqs = [x[0] for x in data] +infer = ClassificationInfer(data_dir) +labels_pred = infer.predict(model, seqs) +``` + + +## Installation + +### Cloning From GitHub + +If you just want to use fastNLP, use: +```shell +git clone https://github.com/fastnlp/fastNLP +cd fastNLP +``` + +### PyTorch Installation + +Visit the [PyTorch official website] for installation instructions based on your system. 
diff --git a/fastNLP/modules/encoder/__init__.py b/fastNLP/modules/encoder/__init__.py
index 25d1b1db..b4e689a7 100644
--- a/fastNLP/modules/encoder/__init__.py
+++ b/fastNLP/modules/encoder/__init__.py
@@ -1,7 +1,9 @@
 from .embedding import Embedding
 from .linear import Linear
 from .lstm import Lstm
+from .conv import Conv
 
 __all__ = ["Lstm",
            "Embedding",
-           "Linear"]
+           "Linear",
+           "Conv"]
diff --git a/fastNLP/modules/encoder/conv.py b/fastNLP/modules/encoder/conv.py
index 1aeedbd5..06a31dd8 100644
--- a/fastNLP/modules/encoder/conv.py
+++ b/fastNLP/modules/encoder/conv.py
@@ -1,8 +1,9 @@
 # python: 3.6
 # encoding: utf-8
 
+import torch
 import torch.nn as nn
-from torch.nn.init import xavier_uniform
+from torch.nn.init import xavier_uniform_
 # import torch.nn.functional as F
 
 
@@ -14,7 +15,7 @@ class Conv(nn.Module):
 
     def __init__(self, in_channels, out_channels, kernel_size,
                  stride=1, padding=0, dilation=1,
-                 groups=1, bias=True):
+                 groups=1, bias=True, activation='relu'):
         super(Conv, self).__init__()
         self.conv = nn.Conv1d(
             in_channels=in_channels,
@@ -25,7 +26,21 @@ class Conv(nn.Module):
             dilation=dilation,
             groups=groups,
             bias=bias)
-        xavier_uniform(self.conv.weight)
+        xavier_uniform_(self.conv.weight)
+
+        activations = {
+            'relu': nn.ReLU(),
+            'tanh': nn.Tanh()}
+        if activation in activations:
+            self.activation = activations[activation]
+        else:
+            raise ValueError(
+                'activation should be one of: '
+                + ', '.join(activations))
 
     def forward(self, x):
-        return self.conv(x)  # [N,C,L]
+        x = torch.transpose(x, 1, 2)  # [N,L,C] -> [N,C,L]
+        x = self.conv(x)  # [N,C_in,L] -> [N,C_out,L]
+        x = self.activation(x)
+        x = torch.transpose(x, 1, 2)  # [N,C,L] -> [N,L,C]
+        return x
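With this change, `Conv` accepts batch-first `[N,L,C]` input, applies the chosen activation, and returns output in the same `[N,L',C_out]` layout. A small usage sketch (import path per the `encoder/__init__.py` change above; sizes match the README example, and `kernel_size=3` with the default `padding=0` shortens the length axis by 2):

```python
import torch
from fastNLP.modules.encoder import Conv

conv = Conv(in_channels=300, out_channels=100, kernel_size=3)
x = torch.randn(4, 7, 300)   # [N,L,C_in]
y = conv(x)                  # [N,L-2,C_out] after the valid convolution
print(y.shape)               # torch.Size([4, 5, 100])

# An unsupported activation name now fails fast at construction time:
# Conv(in_channels=300, out_channels=100, kernel_size=3,
#      activation='sigmoid')  # raises ValueError
```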
diff --git a/fastnlp-architecture.pdf b/fastnlp-architecture.pdf
new file mode 100644
index 00000000..3480f025
Binary files /dev/null and b/fastnlp-architecture.pdf differ
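Taken together, the `Conv` and `MaxPool` changes make the shape comments in the README example line up end to end: `[N,L]` token ids -> `[N,L,300]` embeddings -> `[N,L',100]` convolved features -> `[N,100]` pooled vector -> `[N,class_num]` scores. A sanity check of that flow, assuming `ClassificationModel` from the README example is in scope (the sizes below are made up):

```python
import torch

N, L, vocab_size, n_classes = 2, 10, 1000, 5
tokens = torch.randint(0, vocab_size, (N, L), dtype=torch.long)  # [N,L] token ids
model = ClassificationModel(class_num=n_classes, vocab_size=vocab_size)
out = model(tokens)
print(out.shape)  # expected: torch.Size([2, 5])
```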