Browse Source

add reproduction files

tags/v0.1.0
FengZiYjun 6 years ago
parent
commit
c4028a528a
27 changed files with 61439 additions and 11 deletions
  1. +1
    -1
      fastNLP/loader/dataset_loader.py
  2. +39
    -7
      fastNLP/models/base_model.py
  3. +4
    -3
      fastNLP/models/word_seg_model.py
  4. +110
    -0
      fastNLP/reproduction/CNN-sentence_classification/.gitignore
  5. +77
    -0
      fastNLP/reproduction/CNN-sentence_classification/README.md
  6. +136
    -0
      fastNLP/reproduction/CNN-sentence_classification/dataset.py
  7. +35
    -0
      fastNLP/reproduction/CNN-sentence_classification/model.py
  8. +5331
    -0
      fastNLP/reproduction/CNN-sentence_classification/rt-polaritydata/rt-polarity.neg
  9. +5331
    -0
      fastNLP/reproduction/CNN-sentence_classification/rt-polaritydata/rt-polarity.pos
  10. +93
    -0
      fastNLP/reproduction/CNN-sentence_classification/train.py
  11. +21
    -0
      fastNLP/reproduction/Char-aware_NLM/LICENSE
  12. +40
    -0
      fastNLP/reproduction/Char-aware_NLM/README.md
  13. +145
    -0
      fastNLP/reproduction/Char-aware_NLM/model.py
  14. +117
    -0
      fastNLP/reproduction/Char-aware_NLM/test.py
  15. +3761
    -0
      fastNLP/reproduction/Char-aware_NLM/test.txt
  16. +263
    -0
      fastNLP/reproduction/Char-aware_NLM/train.py
  17. +42068
    -0
      fastNLP/reproduction/Char-aware_NLM/train.txt
  18. +82
    -0
      fastNLP/reproduction/Char-aware_NLM/utilities.py
  19. +3370
    -0
      fastNLP/reproduction/Char-aware_NLM/valid.txt
  20. +36
    -0
      fastNLP/reproduction/HAN-document_classification/README.md
  21. BIN
      fastNLP/reproduction/HAN-document_classification/data/test_samples.pkl
  22. BIN
      fastNLP/reproduction/HAN-document_classification/data/train_samples.pkl
  23. BIN
      fastNLP/reproduction/HAN-document_classification/data/yelp.word2vec
  24. +45
    -0
      fastNLP/reproduction/HAN-document_classification/evaluate.py
  25. +113
    -0
      fastNLP/reproduction/HAN-document_classification/model.py
  26. +50
    -0
      fastNLP/reproduction/HAN-document_classification/preprocess.py
  27. +171
    -0
      fastNLP/reproduction/HAN-document_classification/train.py

+ 1
- 1
fastNLP/loader/dataset_loader.py View File

@@ -1,4 +1,4 @@
from loader.base_loader import BaseLoader
from fastNLP.loader.base_loader import BaseLoader




class DatasetLoader(BaseLoader): class DatasetLoader(BaseLoader):


+ 39
- 7
fastNLP/models/base_model.py View File

@@ -1,11 +1,40 @@
import numpy as np import numpy as np
import torch




class BaseModel(object):
"""The base class of all models.
This class and its subclasses are actually "wrappers" of the PyTorch models.
They act as an interface between Trainer and the deep learning networks.
This interface provides the following methods to be called by Trainer.
class BaseModel(torch.nn.Module):
"""Base PyTorch model for all models.
Three network modules presented:
- embedding module
- aggregation module
- output module
Subclasses must implement these three modules with "components".
"""

def __init__(self):
super(BaseModel, self).__init__()

def forward(self, *inputs):
x = self.encode(*inputs)
x = self.aggregation(x)
x = self.output(x)
return x

def encode(self, x):
raise NotImplementedError

def aggregation(self, x):
raise NotImplementedError

def output(self, x):
raise NotImplementedError


class BaseController(object):
"""Base Controller for all controllers.
This class and its subclasses are actually "controllers" of the PyTorch models.
They act as an interface between Trainer and the PyTorch models.
This controller provides the following methods to be called by Trainer.
- prepare_input - prepare_input
- mode - mode
- define_optimizer - define_optimizer
@@ -15,6 +44,9 @@ class BaseModel(object):
""" """


def __init__(self): def __init__(self):
"""
Define PyTorch model parameters here.
"""
pass pass


def prepare_input(self, data): def prepare_input(self, data):
@@ -63,11 +95,11 @@ class BaseModel(object):
raise NotImplementedError raise NotImplementedError




class ToyModel(BaseModel):
class ToyController(BaseController):
"""This is for code testing.""" """This is for code testing."""


def __init__(self): def __init__(self):
super(ToyModel, self).__init__()
super(ToyController, self).__init__()
self.test_mode = False self.test_mode = False
self.weight = np.random.rand(5, 1) self.weight = np.random.rand(5, 1)
self.bias = np.random.rand() self.bias = np.random.rand()


+ 4
- 3
fastNLP/models/word_seg_model.py View File

@@ -2,9 +2,10 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from model.base_model import BaseModel
from torch.autograd import Variable from torch.autograd import Variable


from fastNLP.models.base_model import BaseModel, BaseController

USE_GPU = True USE_GPU = True




@@ -14,7 +15,7 @@ def to_var(x):
return Variable(x) return Variable(x)




class WordSegModel(BaseModel):
class WordSegModel(BaseController):
""" """
Model controller for WordSeg Model controller for WordSeg
""" """
@@ -91,7 +92,7 @@ class WordSegModel(BaseModel):
self.optimizer.step() self.optimizer.step()




class WordSeg(nn.Module):
class WordSeg(BaseModel):
""" """
PyTorch Network for word segmentation PyTorch Network for word segmentation
""" """


+ 110
- 0
fastNLP/reproduction/CNN-sentence_classification/.gitignore View File

@@ -0,0 +1,110 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache

#custom
GoogleNews-vectors-negative300.bin/
GoogleNews-vectors-negative300.bin.gz
models/
*.swp

+ 77
- 0
fastNLP/reproduction/CNN-sentence_classification/README.md View File

@@ -0,0 +1,77 @@
## Introduction
This is the implementation of [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/abs/1408.5882) paper in PyTorch.
* MRDataset, non-static-model(word2vec rained by Mikolov etal. (2013) on 100 billion words of Google News)
* It can be run in both CPU and GPU
* The best accuracy is 82.61%, which is better than 81.5% in the paper
(by Jingyuan Liu @Fudan University; Email:(fdjingyuan@outlook.com) Welcome to discussion!)

## Requirement
* python 3.6
* pytorch > 0.1
* numpy
* gensim

## Run
STEP 1
install packages like gensim (other needed pakages is the same)
```
pip install gensim
```

STEP 2
install MRdataset and word2vec resources
* MRdataset: you can download the dataset in (https://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz)
* word2vec: you can download the file in (https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit)

Since this file is more than 1.5G, I did not display in folders. If you download the file, please remember modify the path in Function def word_embeddings(path = './GoogleNews-vectors-negative300.bin/'):


STEP 3
train the model
```
python train.py
```
you will get the information printed in the screen, like
```
Epoch [1/20], Iter [100/192] Loss: 0.7008
Test Accuracy: 71.869159 %
Epoch [2/20], Iter [100/192] Loss: 0.5957
Test Accuracy: 75.700935 %
Epoch [3/20], Iter [100/192] Loss: 0.4934
Test Accuracy: 78.130841 %

......
Epoch [20/20], Iter [100/192] Loss: 0.0364
Test Accuracy: 81.495327 %
Best Accuracy: 82.616822 %
Best Model: models/cnn.pkl
```

## Hyperparameters
According to the paper and experiment, I set:

|Epoch|Kernel Size|dropout|learning rate|batch size|
|---|---|---|---|---|
|20|\(h,300,100\)|0.5|0.0001|50|

h = [3,4,5]
If the accuracy is not improved, the learning rate will \*0.8.

## Result
I just tried one dataset : MR. (Other 6 dataset in paper SST-1, SST-2, TREC, CR, MPQA)
There are four models in paper: CNN-rand, CNN-static, CNN-non-static, CNN-multichannel.
I have tried CNN-non-static:A model with pre-trained vectors from word2vec.
All words—including the unknown ones that are randomly initialized and the pretrained vectors are fine-tuned for each task
(which has almost the best performance and the most difficut to implement among the four models)

|Dataset|Class Size|Best Result|Kim's Paper Result|
|---|---|---|---|
|MR|2|82.617%(CNN-non-static)|81.5%(CNN-nonstatic)|



## Reference
* [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/abs/1408.5882)
* https://github.com/Shawn1993/cnn-text-classification-pytorch
* https://github.com/junwang4/CNN-sentence-classification-pytorch-2017/blob/master/utils.py


+ 136
- 0
fastNLP/reproduction/CNN-sentence_classification/dataset.py View File

@@ -0,0 +1,136 @@
import codecs
import random
import re

import gensim
import numpy as np
from gensim import corpora
from torch.utils.data import Dataset


def clean_str(string):
"""
Tokenization/string cleaning for all datasets except for SST.
Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
"""
string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
string = re.sub(r"\'s", " \'s", string)
string = re.sub(r"\'ve", " \'ve", string)
string = re.sub(r"n\'t", " n\'t", string)
string = re.sub(r"\'re", " \'re", string)
string = re.sub(r"\'d", " \'d", string)
string = re.sub(r"\'ll", " \'ll", string)
string = re.sub(r",", " , ", string)
string = re.sub(r"!", " ! ", string)
string = re.sub(r"\(", " \( ", string)
string = re.sub(r"\)", " \) ", string)
string = re.sub(r"\?", " \? ", string)
string = re.sub(r"\s{2,}", " ", string)
return string.strip()


def pad_sentences(sentence, padding_word=" <PAD/>"):
sequence_length = 64
sent = sentence.split()
padded_sentence = sentence + padding_word * (sequence_length - len(sent))
return padded_sentence


# data loader
class MRDataset(Dataset):
def __init__(self):

# load positive and negative sentenses from files
with codecs.open("./rt-polaritydata/rt-polarity.pos", encoding='ISO-8859-1') as f:
positive_examples = list(f.readlines())
with codecs.open("./rt-polaritydata/rt-polarity.neg", encoding='ISO-8859-1') as f:
negative_examples = list(f.readlines())
# s.strip: clear "\n"; clear_str; pad
positive_examples = [pad_sentences(clean_str(s.strip())) for s in positive_examples]
negative_examples = [pad_sentences(clean_str(s.strip())) for s in negative_examples]
self.examples = positive_examples + negative_examples
self.sentences_texts = [sample.split() for sample in self.examples]

# word dictionary
dictionary = corpora.Dictionary(self.sentences_texts)
self.word2id_dict = dictionary.token2id # transform to dict, like {"human":0, "a":1,...}

# set lables: postive is 1; negative is 0
positive_labels = [1 for _ in positive_examples]
negative_labels = [0 for _ in negative_examples]
self.lables = positive_labels + negative_labels
examples_lables = list(zip(self.examples, self.lables))
random.shuffle(examples_lables)
self.MRDataset_frame = examples_lables

# transform word to id
self.MRDataset_wordid = \
[(
np.array([self.word2id_dict[word] for word in sent[0].split()], dtype=np.int64),
sent[1]
) for sent in self.MRDataset_frame]

def word_embeddings(self, path="./GoogleNews-vectors-negative300.bin/GoogleNews-vectors-negative300.bin"):
# establish from google
model = gensim.models.KeyedVectors.load_word2vec_format(path, binary=True)

print('Please wait ... (it could take a while to load the file : {})'.format(path))
word_dict = self.word2id_dict
embedding_weights = np.random.uniform(-0.25, 0.25, (len(self.word2id_dict), 300))

for word in word_dict:
word_id = word_dict[word]
if word in model.wv.vocab:
embedding_weights[word_id, :] = model[word]
return embedding_weights

def __len__(self):
return len(self.MRDataset_frame)

def __getitem__(self, idx):

sample = self.MRDataset_wordid[idx]
return sample

def getsent(self, idx):

sample = self.MRDataset_wordid[idx][0]
return sample

def getlabel(self, idx):

label = self.MRDataset_wordid[idx][1]
return label

def word2id(self):

return self.word2id_dict

def id2word(self):

id2word_dict = dict([val, key] for key, val in self.word2id_dict.items())
return id2word_dict


class train_set(Dataset):

def __init__(self, samples):
self.train_frame = samples

def __len__(self):
return len(self.train_frame)

def __getitem__(self, idx):
return self.train_frame[idx]


class test_set(Dataset):

def __init__(self, samples):
self.test_frame = samples

def __len__(self):
return len(self.test_frame)

def __getitem__(self, idx):
return self.test_frame[idx]

+ 35
- 0
fastNLP/reproduction/CNN-sentence_classification/model.py View File

@@ -0,0 +1,35 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN_text(nn.Module):
def __init__(self, kernel_h=[3, 4, 5], kernel_num=100, embed_num=1000, embed_dim=300, dropout=0.5, L2_constrain=3,
batchsize=50, pretrained_embeddings=None):
super(CNN_text, self).__init__()

self.embedding = nn.Embedding(embed_num, embed_dim)
self.dropout = nn.Dropout(dropout)
if pretrained_embeddings is not None:
self.embedding.weight.data.copy_(torch.from_numpy(pretrained_embeddings))

# the network structure
# Conv2d: input- N,C,H,W output- (50,100,62,1)
self.conv1 = nn.ModuleList([nn.Conv2d(1, 100, (K, 300)) for K in kernel_h])
self.fc1 = nn.Linear(300, 2)

def max_pooling(self, x):
x = F.relu(conv(x)).squeeze(3) # N,C,L - (50,100,62)
x = F.max_pool1d(x, x.size(2)).squeeze(2)
# x.size(2)=62 squeeze: (50,100,1) -> (50,100)
return x

def forward(self, x):
x = self.embedding(x) # output: (N,H,W) = (50,64,300)
x = x.unsqueeze(1) # (N,C,H,W)
x = [F.relu(conv(x)).squeeze(3) for conv in self.conv1] # [N, C, H(50,100,62),(50,100,61),(50,100,60)]
x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # [N,C(50,100),(50,100),(50,100)]
x = torch.cat(x, 1)
x = self.dropout(x)
x = self.fc1(x)
return x

+ 5331
- 0
fastNLP/reproduction/CNN-sentence_classification/rt-polaritydata/rt-polarity.neg
File diff suppressed because it is too large
View File


+ 5331
- 0
fastNLP/reproduction/CNN-sentence_classification/rt-polaritydata/rt-polarity.pos
File diff suppressed because it is too large
View File


+ 93
- 0
fastNLP/reproduction/CNN-sentence_classification/train.py View File

@@ -0,0 +1,93 @@
import os

import
import
import torch
import torch.nn as nn
.dataset as dst
from .model import CNN_text
from torch.autograd import Variable

# Hyper Parameters
batch_size = 50
learning_rate = 0.0001
num_epochs = 20
cuda = True

# split Dataset
dataset = dst.MRDataset()
length = len(dataset)

train_dataset = dataset[:int(0.9 * length)]
test_dataset = dataset[int(0.9 * length):]

train_dataset = dst.train_set(train_dataset)
test_dataset = dst.test_set(test_dataset)

# Data Loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)

# cnn

cnn = CNN_text(embed_num=len(dataset.word2id()), pretrained_embeddings=dataset.word_embeddings())
if cuda:
cnn.cuda()

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)

# train and test
best_acc = None

for epoch in range(num_epochs):
# Train the Model
cnn.train()
for i, (sents, labels) in enumerate(train_loader):
sents = Variable(sents)
labels = Variable(labels)
if cuda:
sents = sents.cuda()
labels = labels.cuda()
optimizer.zero_grad()
outputs = cnn(sents)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

if (i + 1) % 100 == 0:
print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
% (epoch + 1, num_epochs, i + 1, len(train_dataset) // batch_size, loss.data[0]))

# Test the Model
cnn.eval()
correct = 0
total = 0
for sents, labels in test_loader:
sents = Variable(sents)
if cuda:
sents = sents.cuda()
labels = labels.cuda()
outputs = cnn(sents)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum()
acc = 100. * correct / total
print('Test Accuracy: %f %%' % (acc))

if best_acc is None or acc > best_acc:
best_acc = acc
if os.path.exists("models") is False:
os.makedirs("models")
torch.save(cnn.state_dict(), 'models/cnn.pkl')
else:
learning_rate = learning_rate * 0.8

print("Best Accuracy: %f %%" % best_acc)
print("Best Model: models/cnn.pkl")

+ 21
- 0
fastNLP/reproduction/Char-aware_NLM/LICENSE View File

@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2017

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

+ 40
- 0
fastNLP/reproduction/Char-aware_NLM/README.md View File

@@ -0,0 +1,40 @@

# PyTorch-Character-Aware-Neural-Language-Model

This is the PyTorch implementation of character-aware neural language model proposed in this [paper](https://arxiv.org/abs/1508.06615) by Yoon Kim.

## Requiredments
The code is run and tested with **Python 3.5.2** and **PyTorch 0.3.1**.

## HyperParameters
| HyperParam | value |
| ------ | :-------|
| LSTM batch size | 20 |
| LSTM sequence length | 35 |
| LSTM hidden units | 300 |
| epochs | 35 |
| initial learning rate | 1.0 |
| character embedding dimension | 15 |

## Demo
Train the model with split train/valid/test data.

`python train.py`

The trained model will saved in `cache/net.pkl`.
Test the model.

`python test.py`

Best result on test set:
PPl=127.2163
cross entropy loss=4.8459

## Acknowledgement
This implementation borrowed ideas from

https://github.com/jarfo/kchar

https://github.com/cronos123/Character-Aware-Neural-Language-Models



+ 145
- 0
fastNLP/reproduction/Char-aware_NLM/model.py View File

@@ -0,0 +1,145 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class Highway(nn.Module):
"""Highway network"""

def __init__(self, input_size):
super(Highway, self).__init__()
self.fc1 = nn.Linear(input_size, input_size, bias=True)
self.fc2 = nn.Linear(input_size, input_size, bias=True)

def forward(self, x):
t = F.sigmoid(self.fc1(x))
return torch.mul(t, F.relu(self.fc2(x))) + torch.mul(1 - t, x)


class charLM(nn.Module):
"""CNN + highway network + LSTM
# Input:
4D tensor with shape [batch_size, in_channel, height, width]
# Output:
2D Tensor with shape [batch_size, vocab_size]
# Arguments:
char_emb_dim: the size of each character's attention
word_emb_dim: the size of each word's attention
vocab_size: num of unique words
num_char: num of characters
use_gpu: True or False
"""

def __init__(self, char_emb_dim, word_emb_dim,
vocab_size, num_char, use_gpu):
super(charLM, self).__init__()
self.char_emb_dim = char_emb_dim
self.word_emb_dim = word_emb_dim
self.vocab_size = vocab_size

# char attention layer
self.char_embed = nn.Embedding(num_char, char_emb_dim)

# convolutions of filters with different sizes
self.convolutions = []

# list of tuples: (the number of filter, width)
self.filter_num_width = [(25, 1), (50, 2), (75, 3), (100, 4), (125, 5), (150, 6)]

for out_channel, filter_width in self.filter_num_width:
self.convolutions.append(
nn.Conv2d(
1, # in_channel
out_channel, # out_channel
kernel_size=(char_emb_dim, filter_width), # (height, width)
bias=True
)
)

self.highway_input_dim = sum([x for x, y in self.filter_num_width])

self.batch_norm = nn.BatchNorm1d(self.highway_input_dim, affine=False)

# highway net
self.highway1 = Highway(self.highway_input_dim)
self.highway2 = Highway(self.highway_input_dim)

# LSTM
self.lstm_num_layers = 2

self.lstm = nn.LSTM(input_size=self.highway_input_dim,
hidden_size=self.word_emb_dim,
num_layers=self.lstm_num_layers,
bias=True,
dropout=0.5,
batch_first=True)

# output layer
self.dropout = nn.Dropout(p=0.5)
self.linear = nn.Linear(self.word_emb_dim, self.vocab_size)

if use_gpu is True:
for x in range(len(self.convolutions)):
self.convolutions[x] = self.convolutions[x].cuda()
self.highway1 = self.highway1.cuda()
self.highway2 = self.highway2.cuda()
self.lstm = self.lstm.cuda()
self.dropout = self.dropout.cuda()
self.char_embed = self.char_embed.cuda()
self.linear = self.linear.cuda()
self.batch_norm = self.batch_norm.cuda()

def forward(self, x, hidden):
# Input: Variable of Tensor with shape [num_seq, seq_len, max_word_len+2]
# Return: Variable of Tensor with shape [num_words, len(word_dict)]
lstm_batch_size = x.size()[0]
lstm_seq_len = x.size()[1]

x = x.contiguous().view(-1, x.size()[2])
# [num_seq*seq_len, max_word_len+2]

x = self.char_embed(x)
# [num_seq*seq_len, max_word_len+2, char_emb_dim]

x = torch.transpose(x.view(x.size()[0], 1, x.size()[1], -1), 2, 3)
# [num_seq*seq_len, 1, max_word_len+2, char_emb_dim]

x = self.conv_layers(x)
# [num_seq*seq_len, total_num_filters]

x = self.batch_norm(x)
# [num_seq*seq_len, total_num_filters]

x = self.highway1(x)
x = self.highway2(x)
# [num_seq*seq_len, total_num_filters]

x = x.contiguous().view(lstm_batch_size, lstm_seq_len, -1)
# [num_seq, seq_len, total_num_filters]

x, hidden = self.lstm(x, hidden)
# [seq_len, num_seq, hidden_size]

x = self.dropout(x)
# [seq_len, num_seq, hidden_size]

x = x.contiguous().view(lstm_batch_size * lstm_seq_len, -1)
# [num_seq*seq_len, hidden_size]

x = self.linear(x)
# [num_seq*seq_len, vocab_size]
return x, hidden

def conv_layers(self, x):
chosen_list = list()
for conv in self.convolutions:
feature_map = F.tanh(conv(x))
# (batch_size, out_channel, 1, max_word_len-width+1)
chosen = torch.max(feature_map, 3)[0]
# (batch_size, out_channel, 1)
chosen = chosen.squeeze()
# (batch_size, out_channel)
chosen_list.append(chosen)

# (batch_size, total_num_filers)
return torch.cat(chosen_list, 1)

+ 117
- 0
fastNLP/reproduction/Char-aware_NLM/test.py View File

@@ -0,0 +1,117 @@
import os
from collections import namedtuple

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from utilities import *


def to_var(x):
if torch.cuda.is_available():
x = x.cuda()
return Variable(x)


def test(net, data, opt):
net.eval()

test_input = torch.from_numpy(data.test_input)
test_label = torch.from_numpy(data.test_label)

num_seq = test_input.size()[0] // opt.lstm_seq_len
test_input = test_input[:num_seq * opt.lstm_seq_len, :]
# [num_seq, seq_len, max_word_len+2]
test_input = test_input.view(-1, opt.lstm_seq_len, opt.max_word_len + 2)

criterion = nn.CrossEntropyLoss()

loss_list = []
num_hits = 0
total = 0
iterations = test_input.size()[0] // opt.lstm_batch_size
test_generator = batch_generator(test_input, opt.lstm_batch_size)
label_generator = batch_generator(test_label, opt.lstm_batch_size * opt.lstm_seq_len)

hidden = (to_var(torch.zeros(2, opt.lstm_batch_size, opt.word_embed_dim)),
to_var(torch.zeros(2, opt.lstm_batch_size, opt.word_embed_dim)))

add_loss = 0.0
for t in range(iterations):
batch_input = test_generator.__next__()
batch_label = label_generator.__next__()

net.zero_grad()
hidden = [state.detach() for state in hidden]
test_output, hidden = net(to_var(batch_input), hidden)

test_loss = criterion(test_output, to_var(batch_label)).data
loss_list.append(test_loss)
add_loss += test_loss

print("Test Loss={0:.4f}".format(float(add_loss) / iterations))
print("Test PPL={0:.4f}".format(float(np.exp(add_loss / iterations))))


#############################################################

if __name__ == "__main__":

word_embed_dim = 300
char_embedding_dim = 15

if os.path.exists("cache/prep.pt") is False:
print("Cannot find prep.pt")

objetcs = torch.load("cache/prep.pt")

word_dict = objetcs["word_dict"]
char_dict = objetcs["char_dict"]
reverse_word_dict = objetcs["reverse_word_dict"]
max_word_len = objetcs["max_word_len"]
num_words = len(word_dict)

print("word/char dictionary built. Start making inputs.")

if os.path.exists("cache/data_sets.pt") is False:

test_text = read_data("./test.txt")
test_set = np.array(text2vec(test_text, char_dict, max_word_len))

# Labels are next-word index in word_dict with the same length as inputs
test_label = np.array([word_dict[w] for w in test_text[1:]] + [word_dict[test_text[-1]]])

category = {"test": test_set, "tlabel": test_label}
torch.save(category, "cache/data_sets.pt")
else:
data_sets = torch.load("cache/data_sets.pt")
test_set = data_sets["test"]
test_label = data_sets["tlabel"]
train_set = data_sets["tdata"]
train_label = data_sets["trlabel"]

DataTuple = namedtuple("DataTuple", "test_input test_label train_input train_label ")
data = DataTuple(test_input=test_set,
test_label=test_label, train_label=train_label, train_input=train_set)

print("Loaded data sets. Start building network.")

USE_GPU = True
cnn_batch_size = 700
lstm_seq_len = 35
lstm_batch_size = 20

net = torch.load("cache/net.pkl")

Options = namedtuple("Options", ["cnn_batch_size", "lstm_seq_len",
"max_word_len", "lstm_batch_size", "word_embed_dim"])
opt = Options(cnn_batch_size=lstm_seq_len * lstm_batch_size,
lstm_seq_len=lstm_seq_len,
max_word_len=max_word_len,
lstm_batch_size=lstm_batch_size,
word_embed_dim=word_embed_dim)

print("Network built. Start testing.")

test(net, data, opt)

+ 3761
- 0
fastNLP/reproduction/Char-aware_NLM/test.txt
File diff suppressed because it is too large
View File


+ 263
- 0
fastNLP/reproduction/Char-aware_NLM/train.py View File

@@ -0,0 +1,263 @@
import os
from collections import namedtuple

import numpy as np
import torch.optim as optim

from .model import charLM
from .test import test
from .utilities import *


def preprocess():
word_dict, char_dict = create_word_char_dict("charlm.txt", "train.txt", "test.txt")
num_words = len(word_dict)
num_char = len(char_dict)
char_dict["BOW"] = num_char + 1
char_dict["EOW"] = num_char + 2
char_dict["PAD"] = 0

# dict of (int, string)
reverse_word_dict = {value: key for key, value in word_dict.items()}
max_word_len = max([len(word) for word in word_dict])

objects = {
"word_dict": word_dict,
"char_dict": char_dict,
"reverse_word_dict": reverse_word_dict,
"max_word_len": max_word_len
}

torch.save(objects, "cache/prep.pt")
print("Preprocess done.")


def to_var(x):
if torch.cuda.is_available():
x = x.cuda()
return Variable(x)


def train(net, data, opt):
"""
:param net: the pytorch models
:param data: numpy array
:param opt: named tuple
1. random seed
2. define local input
3. training settting: learning rate, loss, etc
4. main loop epoch
5. batchify
6. validation
7. save models
"""
torch.manual_seed(1024)

train_input = torch.from_numpy(data.train_input)
train_label = torch.from_numpy(data.train_label)
valid_input = torch.from_numpy(data.valid_input)
valid_label = torch.from_numpy(data.valid_label)

# [num_seq, seq_len, max_word_len+2]
num_seq = train_input.size()[0] // opt.lstm_seq_len
train_input = train_input[:num_seq * opt.lstm_seq_len, :]
train_input = train_input.view(-1, opt.lstm_seq_len, opt.max_word_len + 2)

num_seq = valid_input.size()[0] // opt.lstm_seq_len
valid_input = valid_input[:num_seq * opt.lstm_seq_len, :]
valid_input = valid_input.view(-1, opt.lstm_seq_len, opt.max_word_len + 2)

num_epoch = opt.epochs
num_iter_per_epoch = train_input.size()[0] // opt.lstm_batch_size

learning_rate = opt.init_lr
old_PPL = 100000
best_PPL = 100000

# Log-SoftMax
criterion = nn.CrossEntropyLoss()

# word_emb_dim == hidden_size / num of hidden units
hidden = (to_var(torch.zeros(2, opt.lstm_batch_size, opt.word_embed_dim)),
to_var(torch.zeros(2, opt.lstm_batch_size, opt.word_embed_dim)))

for epoch in range(num_epoch):

################ Validation ####################
net.eval()
loss_batch = []
PPL_batch = []
iterations = valid_input.size()[0] // opt.lstm_batch_size

valid_generator = batch_generator(valid_input, opt.lstm_batch_size)
vlabel_generator = batch_generator(valid_label, opt.lstm_batch_size * opt.lstm_seq_len)

for t in range(iterations):
batch_input = valid_generator.__next__()
batch_label = vlabel_generator.__next__()

hidden = [state.detach() for state in hidden]
valid_output, hidden = net(to_var(batch_input), hidden)

length = valid_output.size()[0]

# [num_sample-1, len(word_dict)] vs [num_sample-1]
valid_loss = criterion(valid_output, to_var(batch_label))

PPL = torch.exp(valid_loss.data)

loss_batch.append(float(valid_loss))
PPL_batch.append(float(PPL))

PPL = np.mean(PPL_batch)
print("[epoch {}] valid PPL={}".format(epoch, PPL))
print("valid loss={}".format(np.mean(loss_batch)))
print("PPL decrease={}".format(float(old_PPL - PPL)))

# Preserve the best models
if best_PPL > PPL:
best_PPL = PPL
torch.save(net.state_dict(), "cache/models.pt")
torch.save(net, "cache/net.pkl")

# Adjust the learning rate
if float(old_PPL - PPL) <= 1.0:
learning_rate /= 2
print("halved lr:{}".format(learning_rate))

old_PPL = PPL

##################################################
#################### Training ####################
net.train()
optimizer = optim.SGD(net.parameters(),
lr=learning_rate,
momentum=0.85)

# split the first dim
input_generator = batch_generator(train_input, opt.lstm_batch_size)
label_generator = batch_generator(train_label, opt.lstm_batch_size * opt.lstm_seq_len)

for t in range(num_iter_per_epoch):
batch_input = input_generator.__next__()
batch_label = label_generator.__next__()

# detach hidden state of LSTM from last batch
hidden = [state.detach() for state in hidden]

output, hidden = net(to_var(batch_input), hidden)
# [num_word, vocab_size]

loss = criterion(output, to_var(batch_label))

net.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm(net.parameters(), 5, norm_type=2)
optimizer.step()

if (t + 1) % 100 == 0:
print("[epoch {} step {}] train loss={}, Perplexity={}".format(epoch + 1,
t + 1, float(loss.data),
float(np.exp(loss.data))))

torch.save(net.state_dict(), "cache/models.pt")
print("Training finished.")


################################################################

if __name__ == "__main__":

word_embed_dim = 300
char_embedding_dim = 15

if os.path.exists("cache/prep.pt") is False:
preprocess()

objetcs = torch.load("cache/prep.pt")

word_dict = objetcs["word_dict"]
char_dict = objetcs["char_dict"]
reverse_word_dict = objetcs["reverse_word_dict"]
max_word_len = objetcs["max_word_len"]
num_words = len(word_dict)

print("word/char dictionary built. Start making inputs.")

if os.path.exists("cache/data_sets.pt") is False:
train_text = read_data("./train.txt")
valid_text = read_data("./charlm.txt")
test_text = read_data("./test.txt")

train_set = np.array(text2vec(train_text, char_dict, max_word_len))
valid_set = np.array(text2vec(valid_text, char_dict, max_word_len))
test_set = np.array(text2vec(test_text, char_dict, max_word_len))

# Labels are next-word index in word_dict with the same length as inputs
train_label = np.array([word_dict[w] for w in train_text[1:]] + [word_dict[train_text[-1]]])
valid_label = np.array([word_dict[w] for w in valid_text[1:]] + [word_dict[valid_text[-1]]])
test_label = np.array([word_dict[w] for w in test_text[1:]] + [word_dict[test_text[-1]]])

category = {"tdata": train_set, "vdata": valid_set, "test": test_set,
"trlabel": train_label, "vlabel": valid_label, "tlabel": test_label}
torch.save(category, "cache/data_sets.pt")
else:
data_sets = torch.load("cache/data_sets.pt")
train_set = data_sets["tdata"]
valid_set = data_sets["vdata"]
test_set = data_sets["test"]
train_label = data_sets["trlabel"]
valid_label = data_sets["vlabel"]
test_label = data_sets["tlabel"]

DataTuple = namedtuple("DataTuple",
"train_input train_label valid_input valid_label test_input test_label")
data = DataTuple(train_input=train_set,
train_label=train_label,
valid_input=valid_set,
valid_label=valid_label,
test_input=test_set,
test_label=test_label)

print("Loaded data sets. Start building network.")

USE_GPU = True
cnn_batch_size = 700
lstm_seq_len = 35
lstm_batch_size = 20
# cnn_batch_size == lstm_seq_len * lstm_batch_size

net = charLM(char_embedding_dim,
word_embed_dim,
num_words,
len(char_dict),
use_gpu=USE_GPU)

for param in net.parameters():
nn.init.uniform(param.data, -0.05, 0.05)

Options = namedtuple("Options", [
"cnn_batch_size", "init_lr", "lstm_seq_len",
"max_word_len", "lstm_batch_size", "epochs",
"word_embed_dim"])
opt = Options(cnn_batch_size=lstm_seq_len * lstm_batch_size,
init_lr=1.0,
lstm_seq_len=lstm_seq_len,
max_word_len=max_word_len,
lstm_batch_size=lstm_batch_size,
epochs=35,
word_embed_dim=word_embed_dim)

print("Network built. Start training.")

# You can stop training anytime by "ctrl+C"
try:
train(net, data, opt)
except KeyboardInterrupt:
print('-' * 89)
print('Exiting from training early')

torch.save(net, "cache/net.pkl")
print("save net")

test(net, data, opt)

+ 42068
- 0
fastNLP/reproduction/Char-aware_NLM/train.txt
File diff suppressed because it is too large
View File


+ 82
- 0
fastNLP/reproduction/Char-aware_NLM/utilities.py View File

@@ -0,0 +1,82 @@
import torch
import torch.nn.functional as F


def batch_generator(x, batch_size):
# x: [num_words, in_channel, height, width]
# partitions x into batches
num_step = x.size()[0] // batch_size
for t in range(num_step):
yield x[t * batch_size:(t + 1) * batch_size]


def text2vec(words, char_dict, max_word_len):
""" Return list of list of int """
word_vec = []
for word in words:
vec = [char_dict[ch] for ch in word]
if len(vec) < max_word_len:
vec += [char_dict["PAD"] for _ in range(max_word_len - len(vec))]
vec = [char_dict["BOW"]] + vec + [char_dict["EOW"]]
word_vec.append(vec)
return word_vec


def seq2vec(input_words, char_embedding, char_embedding_dim, char_table):
""" convert the input strings into character embeddings """
# input_words == list of string
# char_embedding == torch.nn.Embedding
# char_embedding_dim == int
# char_table == list of unique chars
# Returns: tensor of shape [len(input_words), char_embedding_dim, max_word_len+2]
max_word_len = max([len(word) for word in input_words])
print("max_word_len={}".format(max_word_len))
tensor_list = []

start_column = torch.ones(char_embedding_dim, 1)
end_column = torch.ones(char_embedding_dim, 1)

for word in input_words:
# convert string to word attention
word_encoding = char_embedding_lookup(word, char_embedding, char_table)
# add start and end columns
word_encoding = torch.cat([start_column, word_encoding, end_column], 1)
# zero-pad right columns
word_encoding = F.pad(word_encoding, (0, max_word_len - word_encoding.size()[1] + 2)).data
# create dimension
word_encoding = word_encoding.unsqueeze(0)

tensor_list.append(word_encoding)

return torch.cat(tensor_list, 0)


def read_data(file_name):
# Return: list of strings
with open(file_name, 'r') as f:
corpus = f.read().lower()
import re
corpus = re.sub(r"<unk>", "unk", corpus)
return corpus.split()


def get_char_dict(vocabulary):
# vocabulary == dict of (word, int)
# Return: dict of (char, int), starting from 1
char_dict = dict()
count = 1
for word in vocabulary:
for ch in word:
if ch not in char_dict:
char_dict[ch] = count
count += 1
return char_dict


def create_word_char_dict(*file_name):
text = []
for file in file_name:
text += read_data(file)
word_dict = {word: ix for ix, word in enumerate(set(text))}
char_dict = get_char_dict(word_dict)
return word_dict, char_dict

+ 3370
- 0
fastNLP/reproduction/Char-aware_NLM/valid.txt
File diff suppressed because it is too large
View File


+ 36
- 0
fastNLP/reproduction/HAN-document_classification/README.md View File

@@ -0,0 +1,36 @@
## Introduction
This is the implementation of [Hierarchical Attention Networks for Document Classification](https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf) paper in PyTorch.
* Dataset is 600k documents extracted from [Yelp 2018](https://www.yelp.com/dataset) customer reviews
* Use [NLTK](http://www.nltk.org/) and [Stanford CoreNLP](https://stanfordnlp.github.io/CoreNLP/) to tokenize documents and sentences
* Both CPU & GPU support
* The best accuracy is 71%, reaching the same performance in the paper

## Requirement
* python 3.6
* pytorch = 0.3.0
* numpy
* gensim
* nltk
* coreNLP

## Parameters
According to the paper and experiment, I set model parameters:
|word embedding dimension|GRU hidden size|GRU layer|word/sentence context vector dimension|
|---|---|---|---|
|200|50|1|100|

And the training parameters:
|Epoch|learning rate|momentum|batch size|
|---|---|---|---|
|3|0.01|0.9|64|

## Run
1. Prepare dataset. Download the [data set](https://www.yelp.com/dataset), and unzip the custom reviews as a file. Use preprocess.py to transform file into data set foe model input.
2. Train the model. Word enbedding of train data in 'yelp.word2vec'. The model will trained and autosaved in 'model.dict'
```
python train
```
3. Test the model.
```
python evaluate
```

BIN
fastNLP/reproduction/HAN-document_classification/data/test_samples.pkl View File


BIN
fastNLP/reproduction/HAN-document_classification/data/train_samples.pkl View File


BIN
fastNLP/reproduction/HAN-document_classification/data/yelp.word2vec View File


+ 45
- 0
fastNLP/reproduction/HAN-document_classification/evaluate.py View File

@@ -0,0 +1,45 @@
from model import *
from train import *


def evaluate(net, dataset, bactch_size=64, use_cuda=False):
dataloader = DataLoader(dataset, batch_size=bactch_size, collate_fn=collate, num_workers=0)
count = 0
if use_cuda:
net.cuda()
for i, batch_samples in enumerate(dataloader):
x, y = batch_samples
doc_list = []
for sample in x:
doc = []
for sent_vec in sample:
if use_cuda:
sent_vec = sent_vec.cuda()
doc.append(Variable(sent_vec, volatile=True))
doc_list.append(pack_sequence(doc))
if use_cuda:
y = y.cuda()
predicts = net(doc_list)
p, idx = torch.max(predicts, dim=1)
idx = idx.data
count += torch.sum(torch.eq(idx, y))
return count


if __name__ == '__main__':
'''
Evaluate the performance of models
'''
from gensim.models import Word2Vec

embed_model = Word2Vec.load('yelp.word2vec')
embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
del embed_model

net = HAN(input_size=200, output_size=5,
word_hidden_size=50, word_num_layers=1, word_context_size=100,
sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
net.load_state_dict(torch.load('models.dict'))
test_dataset = YelpDocSet('reviews', 199, 4, embedding)
correct = evaluate(net, test_dataset, True)
print('accuracy {}'.format(correct / len(test_dataset)))

+ 113
- 0
fastNLP/reproduction/HAN-document_classification/model.py View File

@@ -0,0 +1,113 @@
import torch
import torch.nn as nn
from torch.autograd import Variable


def pack_sequence(tensor_seq, padding_value=0.0):
if len(tensor_seq) <= 0:
return
length = [v.size(0) for v in tensor_seq]
max_len = max(length)
size = [len(tensor_seq), max_len]
size.extend(list(tensor_seq[0].size()[1:]))
ans = torch.Tensor(*size).fill_(padding_value)
if tensor_seq[0].data.is_cuda:
ans = ans.cuda()
ans = Variable(ans)
for i, v in enumerate(tensor_seq):
ans[i, :length[i], :] = v
return ans


class HAN(nn.Module):
def __init__(self, input_size, output_size,
word_hidden_size, word_num_layers, word_context_size,
sent_hidden_size, sent_num_layers, sent_context_size):
super(HAN, self).__init__()

self.word_layer = AttentionNet(input_size,
word_hidden_size,
word_num_layers,
word_context_size)
self.sent_layer = AttentionNet(2 * word_hidden_size,
sent_hidden_size,
sent_num_layers,
sent_context_size)
self.output_layer = nn.Linear(2 * sent_hidden_size, output_size)
self.softmax = nn.LogSoftmax(dim=1)

def forward(self, batch_doc):
# input is a sequence of matrix
doc_vec_list = []
for doc in batch_doc:
sent_mat = self.word_layer(doc) # doc's dim (num_sent, seq_len, word_dim)
doc_vec_list.append(sent_mat) # sent_mat's dim (num_sent, vec_dim)
doc_vec = self.sent_layer(pack_sequence(doc_vec_list))
output = self.softmax(self.output_layer(doc_vec))
return output


class AttentionNet(nn.Module):
def __init__(self, input_size, gru_hidden_size, gru_num_layers, context_vec_size):
super(AttentionNet, self).__init__()

self.input_size = input_size
self.gru_hidden_size = gru_hidden_size
self.gru_num_layers = gru_num_layers
self.context_vec_size = context_vec_size

# Encoder
self.gru = nn.GRU(input_size=input_size,
hidden_size=gru_hidden_size,
num_layers=gru_num_layers,
batch_first=True,
bidirectional=True)
# Attention
self.fc = nn.Linear(2 * gru_hidden_size, context_vec_size)
self.tanh = nn.Tanh()
self.softmax = nn.Softmax(dim=1)
# context vector
self.context_vec = nn.Parameter(torch.Tensor(context_vec_size, 1))
self.context_vec.data.uniform_(-0.1, 0.1)

def forward(self, inputs):
# GRU part
h_t, hidden = self.gru(inputs) # inputs's dim (batch_size, seq_len, word_dim)
u = self.tanh(self.fc(h_t))
# Attention part
alpha = self.softmax(torch.matmul(u, self.context_vec)) # u's dim (batch_size, seq_len, context_vec_size)
output = torch.bmm(torch.transpose(h_t, 1, 2), alpha) # alpha's dim (batch_size, seq_len, 1)
return torch.squeeze(output, dim=2) # output's dim (batch_size, 2*hidden_size, 1)


if __name__ == '__main__':
'''
Test the models correctness
'''
import numpy as np

use_cuda = True
net = HAN(input_size=200, output_size=5,
word_hidden_size=50, word_num_layers=1, word_context_size=100,
sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
criterion = nn.NLLLoss()
test_time = 10
batch_size = 64
if use_cuda:
net.cuda()
print('test training')
for step in range(test_time):
x_data = [torch.randn(np.random.randint(1, 10), 200, 200) for i in range(batch_size)]
y_data = torch.LongTensor([np.random.randint(0, 5) for i in range(batch_size)])
if use_cuda:
x_data = [x_i.cuda() for x_i in x_data]
y_data = y_data.cuda()
x = [Variable(x_i) for x_i in x_data]
y = Variable(y_data)
predict = net(x)
loss = criterion(predict, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.data[0])

+ 50
- 0
fastNLP/reproduction/HAN-document_classification/preprocess.py View File

@@ -0,0 +1,50 @@
''''
Tokenize yelp dataset's documents using stanford core nlp
'''

import json
import os
import pickle

import nltk
from nltk.tokenize import stanford

input_filename = 'review.json'

# config for stanford core nlp
os.environ['JAVAHOME'] = 'D:\\java\\bin\\java.exe'
path_to_jar = 'E:\\College\\fudanNLP\\stanford-corenlp-full-2018-02-27\\stanford-corenlp-3.9.1.jar'
tokenizer = stanford.CoreNLPTokenizer()

in_dirname = 'review'
out_dirname = 'reviews'

f = open(input_filename, encoding='utf-8')
samples = []
j = 0
for i, line in enumerate(f.readlines()):
review = json.loads(line)
samples.append((review['stars'], review['text']))
if (i + 1) % 5000 == 0:
print(i)
pickle.dump(samples, open(in_dirname + '/samples%d.pkl' % j, 'wb'))
j += 1
samples = []
pickle.dump(samples, open(in_dirname + '/samples%d.pkl' % j, 'wb'))
# samples = pickle.load(open(out_dirname + '/samples0.pkl', 'rb'))
# print(samples[0])


for fn in os.listdir(in_dirname):
print(fn)
precessed = []
for stars, text in pickle.load(open(os.path.join(in_dirname, fn), 'rb')):
tokens = []
sents = nltk.tokenize.sent_tokenize(text)
for s in sents:
tokens.append(tokenizer.tokenize(s))
precessed.append((stars, tokens))
# print(tokens)
if len(precessed) % 100 == 0:
print(len(precessed))
pickle.dump(precessed, open(os.path.join(out_dirname, fn), 'wb'))

+ 171
- 0
fastNLP/reproduction/HAN-document_classification/train.py View File

@@ -0,0 +1,171 @@
import os
import pickle

import numpy as np
import torch
from model import *


class SentIter:
def __init__(self, dirname, count):
self.dirname = dirname
self.count = int(count)

def __iter__(self):
for f in os.listdir(self.dirname)[:self.count]:
with open(os.path.join(self.dirname, f), 'rb') as f:
for y, x in pickle.load(f):
for sent in x:
yield sent


def train_word_vec():
# load data
dirname = 'reviews'
sents = SentIter(dirname, 238)
# define models and train
model = models.Word2Vec(size=200, sg=0, workers=4, min_count=5)
model.build_vocab(sents)
model.train(sents, total_examples=model.corpus_count, epochs=10)
model.save('yelp.word2vec')
print(model.wv.similarity('woman', 'man'))
print(model.wv.similarity('nice', 'awful'))


class Embedding_layer:
def __init__(self, wv, vector_size):
self.wv = wv
self.vector_size = vector_size

def get_vec(self, w):
try:
v = self.wv[w]
except KeyError as e:
v = np.random.randn(self.vector_size)
return v


from torch.utils.data import DataLoader, Dataset


class YelpDocSet(Dataset):
def __init__(self, dirname, start_file, num_files, embedding):
self.dirname = dirname
self.num_files = num_files
self._files = os.listdir(dirname)[start_file:start_file + num_files]
self.embedding = embedding
self._cache = [(-1, None) for i in range(5)]

def get_doc(self, n):
file_id = n // 5000
idx = file_id % 5
if self._cache[idx][0] != file_id:
with open(os.path.join(self.dirname, self._files[file_id]), 'rb') as f:
self._cache[idx] = (file_id, pickle.load(f))
y, x = self._cache[idx][1][n % 5000]
sents = []
for s_list in x:
sents.append(' '.join(s_list))
x = '\n'.join(sents)
return x, y - 1

def __len__(self):
return len(self._files) * 5000

def __getitem__(self, n):
file_id = n // 5000
idx = file_id % 5
if self._cache[idx][0] != file_id:
print('load {} to {}'.format(file_id, idx))
with open(os.path.join(self.dirname, self._files[file_id]), 'rb') as f:
self._cache[idx] = (file_id, pickle.load(f))
y, x = self._cache[idx][1][n % 5000]
doc = []
for sent in x:
if len(sent) == 0:
continue
sent_vec = []
for word in sent:
vec = self.embedding.get_vec(word)
sent_vec.append(vec.tolist())
sent_vec = torch.Tensor(sent_vec)
doc.append(sent_vec)
if len(doc) == 0:
doc = [torch.zeros(1, 200)]
return doc, y - 1


def collate(iterable):
y_list = []
x_list = []
for x, y in iterable:
y_list.append(y)
x_list.append(x)
return x_list, torch.LongTensor(y_list)


def train(net, dataset, num_epoch, batch_size, print_size=10, use_cuda=False):
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
criterion = nn.NLLLoss()

dataloader = DataLoader(dataset,
batch_size=batch_size,
collate_fn=collate,
num_workers=0)
running_loss = 0.0

if use_cuda:
net.cuda()
print('start training')
for epoch in range(num_epoch):
for i, batch_samples in enumerate(dataloader):
x, y = batch_samples
doc_list = []
for sample in x:
doc = []
for sent_vec in sample:
if use_cuda:
sent_vec = sent_vec.cuda()
doc.append(Variable(sent_vec))
doc_list.append(pack_sequence(doc))
if use_cuda:
y = y.cuda()
y = Variable(y)
predict = net(doc_list)
loss = criterion(predict, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.data[0]
if i % print_size == print_size - 1:
print('{}, {}'.format(i + 1, running_loss / print_size))
running_loss = 0.0
torch.save(net.state_dict(), 'models.dict')
torch.save(net.state_dict(), 'models.dict')


if __name__ == '__main__':
'''
Train process
'''
from gensim.models import Word2Vec
from gensim import models

train_word_vec()

embed_model = Word2Vec.load('yelp.word2vec')
embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
del embed_model
start_file = 0
dataset = YelpDocSet('reviews', start_file, 120 - start_file, embedding)
print('training data size {}'.format(len(dataset)))
net = HAN(input_size=200, output_size=5,
word_hidden_size=50, word_num_layers=1, word_context_size=100,
sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
try:
net.load_state_dict(torch.load('models.dict'))
print("last time trained models has loaded")
except Exception:
print("cannot load models, train the inital models")

train(net, dataset, num_epoch=5, batch_size=64, use_cuda=True)

Loading…
Cancel
Save