
Update readme, conv and max_pool

tags/v0.1.0

Ke Zhen committed 6 years ago · commit 7f23b40ad7
7 changed files with 154 additions and 12 deletions:

1. README.md (+121, −2)
2. fastNLP/core/inference.py (+2, −2)
3. fastNLP/modules/aggregation/__init__.py (+5, −0)
4. fastNLP/modules/aggregation/max_pool.py (+4, −3)
5. fastNLP/modules/encoder/__init__.py (+3, −1)
6. fastNLP/modules/encoder/conv.py (+19, −4)
7. fastnlp-architecture.pdf (BIN)

README.md (+121, −2)

@@ -1,6 +1,126 @@
-# FastNLP
+# fastNLP

[![Build Status](https://travis-ci.org/fastnlp/fastNLP.svg?branch=master)](https://travis-ci.org/fastnlp/fastNLP)
[![codecov](https://codecov.io/gh/fastnlp/fastNLP/branch/master/graph/badge.svg)](https://codecov.io/gh/fastnlp/fastNLP)

fastNLP is a modular Natural Language Processing system based on PyTorch, built for fast development of NLP tools. It decomposes deep-learning NLP models into modules that fall into four categories: encoder, interaction, aggregation, and decoder, each category containing several implemented modules. Encoder modules encode the input into an abstract representation; interaction modules let the information within the representation interact; aggregation modules aggregate and reduce information; and decoder modules decode the representation into the output. Most current NLP models can be built from these modules, which greatly simplifies the development of NLP models. The architecture of fastNLP is shown in the figure below:

![](fastnlp-architecture.pdf)


## Requirements

- numpy>=1.14.2
- torch==0.4.0
- torchvision>=0.1.8


## Resources

- [Documentation](https://github.com/fastnlp/fastNLP)
- [Source Code](https://github.com/fastnlp/fastNLP)


## Example

### Basic Usage

A typical fastNLP routine is composed of four phases: loading the dataset, pre-processing the data, constructing the model, and training the model.
```python
from fastNLP.models.base_model import BaseModel
from fastNLP.modules import encoder
from fastNLP.modules import aggregation

from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.preprocess import ClassPreprocess
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.core.inference import ClassificationInfer


class ClassificationModel(BaseModel):
    """
    Simple text classification model based on CNN.
    """

    def __init__(self, class_num, vocab_size):
        super(ClassificationModel, self).__init__()

        self.embed = encoder.Embedding(nums=vocab_size, dims=300)
        self.conv = encoder.Conv(
            in_channels=300, out_channels=100, kernel_size=3)
        self.pool = aggregation.MaxPool()
        self.output = encoder.Linear(input_size=100, output_size=class_num)

    def forward(self, x):
        x = self.embed(x)   # [N,L] -> [N,L,C]
        x = self.conv(x)    # [N,L,C_in] -> [N,L,C_out]
        x = self.pool(x)    # [N,L,C] -> [N,C]
        x = self.output(x)  # [N,C] -> [N,N_class]
        return x


data_dir = 'data' # directory to save data and model
train_path = 'test/data_for_tests/text_classify.txt' # training set file

# load dataset
ds_loader = ClassDatasetLoader("train", train_path)
data = ds_loader.load()

# pre-process dataset
pre = ClassPreprocess(data_dir)
vocab_size, n_classes = pre.process(data, "data_train.pkl")

# construct model
model = ClassificationModel(class_num=n_classes, vocab_size=vocab_size)

# train model
train_args = {
    "epochs": 20,
    "batch_size": 50,
    "pickle_path": data_dir,
    "validate": False,
    "save_best_dev": False,
    "model_saved_path": None,
    "use_cuda": True,
    "learn_rate": 1e-3,
    "momentum": 0.9}
trainer = ClassificationTrainer(train_args)
trainer.train(model)

# predict using model
seqs = [x[0] for x in data]
infer = ClassificationInfer(data_dir)
labels_pred = infer.predict(model, seqs)
```


## Installation

### Cloning From GitHub

If you just want to use fastNLP, clone the repository:
```shell
git clone https://github.com/fastnlp/fastNLP
cd fastNLP
```

### PyTorch Installation

Visit the [PyTorch official website](https://pytorch.org/) for installation instructions based on your system. In general, you can use:
```shell
# using conda
conda install pytorch torchvision -c pytorch
# or using pip
pip3 install torch torchvision
```


## Project Structure

```
FastNLP
├── docs
@@ -90,5 +210,4 @@ FastNLP
├── test_seq_labeling.py
├── test_tester.py
└── test_trainer.py
```

fastNLP/core/inference.py (+2, −2)

@@ -182,6 +182,6 @@ class ClassificationInfer(Inference):
         """
         results = []
         for batch_out in batch_outputs:
-            idx = np.argmax(batch_out.detach().numpy())
-            results.append(self.index2label[idx])
+            idx = np.argmax(batch_out.detach().numpy(), axis=-1)
+            results.extend([self.index2label[i] for i in idx])
         return results
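This fixes batch prediction: `np.argmax` without an `axis` argument operates on the flattened array, so for a `[batch_size, num_classes]` output it returned a single index for the whole batch and appended only one label. With `axis=-1` it returns one class index per example. A minimal numpy sketch of the difference:

```python
import numpy as np

# A hypothetical batch of two examples with three classes each.
batch_out = np.array([[0.1, 0.7, 0.2],
                      [0.8, 0.1, 0.1]])

print(np.argmax(batch_out))           # 3 -- index into the flattened array
print(np.argmax(batch_out, axis=-1))  # [1 0] -- one predicted class per example
```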

fastNLP/modules/aggregation/__init__.py (+5, −0)

@@ -0,0 +1,5 @@
+from .max_pool import MaxPool
+
+__all__ = [
+    'MaxPool'
+]
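The new package `__init__.py` re-exports `MaxPool` at the package level, which is what the README example above relies on:

```python
from fastNLP.modules import aggregation

# Same class as fastNLP.modules.aggregation.max_pool.MaxPool
pool = aggregation.MaxPool()
```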

fastNLP/modules/aggregation/max_pool.py (+4, −3)

@@ -1,6 +1,7 @@
# python: 3.6 # python: 3.6
# encoding: utf-8 # encoding: utf-8


import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F


@@ -15,12 +16,12 @@ class MaxPool(nn.Module):
self.dilation = dilation self.dilation = dilation


def forward(self, x): def forward(self, x):
# [N,C,L] -> [N,C]
x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L]
kernel_size = x.size(2) kernel_size = x.size(2)
x = F.max_pool1d(
x = F.max_pool1d( # [N,L,C] -> [N,C,1]
input=x, input=x,
kernel_size=kernel_size, kernel_size=kernel_size,
stride=self.stride, stride=self.stride,
padding=self.padding, padding=self.padding,
dilation=self.dilation) dilation=self.dilation)
return x.squeeze(dim=-1)
return x.squeeze(dim=-1) # [N,C,1] -> [N,C]
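With the transpose added, `MaxPool` accepts the `[N,L,C]` layout that the encoder modules produce and max-pools over the full length dimension. A quick shape check, assuming the constructor defaults leave `stride`, `padding`, and `dilation` at the `F.max_pool1d` defaults:

```python
import torch
from fastNLP.modules.aggregation import MaxPool

x = torch.randn(4, 10, 100)  # [N,L,C]: batch 4, sequence length 10, 100 channels
pool = MaxPool()
y = pool(x)                  # transpose to [N,C,L], pool over L, squeeze
print(y.size())              # torch.Size([4, 100])
```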

fastNLP/modules/encoder/__init__.py (+3, −1)

@@ -1,7 +1,9 @@
from .embedding import Embedding from .embedding import Embedding
from .linear import Linear from .linear import Linear
from .lstm import Lstm from .lstm import Lstm
from .conv import Conv


__all__ = ["Lstm", __all__ = ["Lstm",
"Embedding", "Embedding",
"Linear"]
"Linear",
"Conv"]

fastNLP/modules/encoder/conv.py (+19, −4)

@@ -1,8 +1,9 @@
# python: 3.6 # python: 3.6
# encoding: utf-8 # encoding: utf-8


import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.init import xavier_uniform
from torch.nn.init import xavier_uniform_
# import torch.nn.functional as F # import torch.nn.functional as F




@@ -14,7 +15,7 @@ class Conv(nn.Module):


def __init__(self, in_channels, out_channels, kernel_size, def __init__(self, in_channels, out_channels, kernel_size,
stride=1, padding=0, dilation=1, stride=1, padding=0, dilation=1,
groups=1, bias=True):
groups=1, bias=True, activation='relu'):
super(Conv, self).__init__() super(Conv, self).__init__()
self.conv = nn.Conv1d( self.conv = nn.Conv1d(
in_channels=in_channels, in_channels=in_channels,
@@ -25,7 +26,21 @@ class Conv(nn.Module):
dilation=dilation, dilation=dilation,
groups=groups, groups=groups,
bias=bias) bias=bias)
xavier_uniform(self.conv.weight)
xavier_uniform_(self.conv.weight)

activations = {
'relu': nn.ReLU(),
'tanh': nn.Tanh()}
if activation in activations:
self.activation = activations[activation]
else:
raise Exception(
'Should choose activation function from: ' +
', '.join([x for x in activations]))


def forward(self, x): def forward(self, x):
return self.conv(x) # [N,C,L]
x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L]
x = self.conv(x) # [N,C_in,L] -> [N,C_out,L]
x = self.activation(x)
x = torch.transpose(x, 1, 2) # [N,C,L] -> [N,L,C]
return x
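`Conv` now performs the layout conversion internally and applies the chosen non-linearity, so callers can keep everything in `[N,L,C]`, as the README example above does. A small usage sketch; the `padding=1` setting, which keeps the sequence length unchanged for `kernel_size=3`, is illustrative rather than taken from the commit:

```python
import torch
from fastNLP.modules.encoder import Conv  # exported by the updated __init__.py

conv = Conv(in_channels=300, out_channels=100, kernel_size=3,
            padding=1, activation='relu')
x = torch.randn(32, 20, 300)  # [N,L,C_in]
y = conv(x)                   # conv + ReLU, layout restored to [N,L,C_out]
print(y.size())               # torch.Size([32, 20, 100])

# Anything other than 'relu' or 'tanh' raises:
# Exception: Should choose activation function from: relu, tanh
```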

fastnlp-architecture.pdf (BIN)

