
Merge pull request #10 from SrWYG/dev0.5.0

Dev0.5.0
tags/v0.4.10
lyhuang18 committed via GitHub 5 years ago
parent commit 48d6f38f23
6 changed files with 346 additions and 12 deletions
1. reproduction/text_classification/README.md (+22, -0)
2. reproduction/text_classification/data/sstLoader.py (+97, -8)
3. reproduction/text_classification/data/yelpLoader.py (+6, -0)
4. reproduction/text_classification/model/HAN.py (+109, -0)
5. reproduction/text_classification/train_HAN.py (+109, -0)
6. reproduction/text_classification/train_char_cnn.py (+3, -4)

reproduction/text_classification/README.md (+22, -0)

@@ -0,0 +1,22 @@
# Reproduction of text_classification models
The following models are reproduced here using fastNLP:
char_cnn: paper [Character-level Convolutional Networks for Text Classification](https://arxiv.org/pdf/1509.01626v3.pdf)
dpcnn: paper [Deep Pyramid Convolutional Neural Networks for Text Categorization](https://ai.tencent.com/ailab/media/publications/ACL3-Brady.pdf)
HAN: paper [Hierarchical Attention Networks for Document Classification](https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf)
# To be added
awd-lstm:
lstm_self_attention (BCN?):

# Datasets and reproduction results

Results reproduced with fastNLP vs. results reported in the papers (the number before the / is the fastNLP implementation, the number after it is the paper's result; - means the paper does not report a result on that dataset).

model name | yelp_p | sst-2 | IMDB
:---: | :---: | :---: | :---:
char_cnn | 93.80/95.12 | - | -
dpcnn | 95.50/97.36 | - | -
HAN | - | - | -
BCN | - | - | -
awd-lstm | - | - | -
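
A minimal sketch of driving the new SST loader (hypothetical file paths; assumes the PTB-tree files from the URL in `sstLoader.py` have been downloaded and that `reproduction/text_classification` is on the path):

```python
from data.sstLoader import SSTLoader

# binary (SST-2) labels by default; pass fine_grained=True for SST-5
loader = SSTLoader(subtree=False, fine_grained=False)
info = loader.process(paths={'train': 'sst/train.txt',
                             'dev': 'sst/dev.txt',
                             'test': 'sst/test.txt'})
print(info.vocabs['words'])  # vocabulary built over the 'words' field
```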


reproduction/text_classification/data/sstLoader.py (+97, -8)

@@ -1,13 +1,102 @@
import csv
from typing import Iterable, Union, Dict
from nltk import Tree
from fastNLP import DataSet, Instance, Vocabulary
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataInfo, DataSetLoader
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
from fastNLP.io.file_reader import _read_json
from reproduction.utils import check_dataloader_paths

class SSTLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`

    Reads the SST dataset. The DataSet contains the fields::

        words: list(str), the text to classify
        target: str, the label of the text

    Data source: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip

    :param subtree: whether to expand the data into subtrees to augment the dataset. Default: ``False``
    :param fine_grained: whether to use the SST-5 label scheme; if ``False``, SST-2 is used. Default: ``False``
    """
    URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
    DATA_DIR = 'sst/'

    def __init__(self, subtree=False, fine_grained=False):
        self.subtree = subtree

        tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
                 '3': 'positive', '4': 'very positive'}
        if not fine_grained:
            tag_v['0'] = tag_v['1']
            tag_v['4'] = tag_v['3']
        self.tag_v = tag_v

    def _load(self, path):
        """
        :param str path: path to the data file
        :return: an object of type :class:`~fastNLP.DataSet`
        """
        datas = []
        with open(path, 'r', encoding='utf-8') as f:
            for l in f:
                datas.extend([(s, self.tag_v[t])
                              for s, t in self._get_one(l, self.subtree)])
        ds = DataSet()
        for words, tag in datas:
            ds.append(Instance(words=words, target=tag))
        return ds

    @staticmethod
    def _get_one(data, subtree):
        tree = Tree.fromstring(data)
        if subtree:
            return [(t.leaves(), t.label()) for t in tree.subtrees()]
        return [(tree.leaves(), tree.label())]
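
    # Example (illustrative): for the SST line "(3 (2 It) (3 (2 's) (3 good)))",
    # _get_one(line, subtree=False) returns [(['It', "'s", 'good'], '3')], while
    # subtree=True yields one (leaves, label) pair for every subtree of the parse.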

    def process(self,
                paths,
                train_ds: Iterable[str] = None,
                src_vocab_op: VocabularyOption = None,
                tgt_vocab_op: VocabularyOption = None,
                src_embed_op: EmbeddingOption = None):
        input_name, target_name = 'words', 'target'
        src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)

        info = DataInfo(datasets=self.load(paths))
        _train_ds = [info.datasets[name]
                     for name in train_ds] if train_ds else info.datasets.values()
        src_vocab.from_dataset(*_train_ds, field_name=input_name)
        tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
        src_vocab.index_dataset(
            *info.datasets.values(),
            field_name=input_name, new_field_name=input_name)
        tgt_vocab.index_dataset(
            *info.datasets.values(),
            field_name=target_name, new_field_name=target_name)
        info.vocabs = {
            input_name: src_vocab,
            target_name: tgt_vocab
        }

        if src_embed_op is not None:
            src_embed_op.vocab = src_vocab
            init_emb = EmbedLoader.load_with_vocab(**src_embed_op)
            info.embeddings[input_name] = init_emb

        for name, dataset in info.datasets.items():
            dataset.set_input(input_name)
            dataset.set_target(target_name)

        return info

class sst2Loader(DataSetLoader):
'''


reproduction/text_classification/data/yelpLoader.py (+6, -0)

@@ -184,6 +184,12 @@ class yelpLoader(DataSetLoader):

        info.vocabs[target_name] = tgt_vocab

        # hold out 10% of the training data as a dev set, without shuffling
        info.datasets['train'], info.datasets['dev'] = info.datasets['train'].split(0.1, shuffle=False)

        for name, dataset in info.datasets.items():
            dataset.set_input("words")
            dataset.set_target("target")

        return info

if __name__ == "__main__":


reproduction/text_classification/model/HAN.py (+109, -0)

@@ -0,0 +1,109 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
from fastNLP.modules.utils import get_embeddings
from fastNLP.core import Const as C


def pack_sequence(tensor_seq, padding_value=0.0):
    if len(tensor_seq) <= 0:
        return
    length = [v.size(0) for v in tensor_seq]
    max_len = max(length)
    size = [len(tensor_seq), max_len]
    size.extend(list(tensor_seq[0].size()[1:]))
    ans = torch.Tensor(*size).fill_(padding_value)
    if tensor_seq[0].data.is_cuda:
        ans = ans.cuda()
    ans = Variable(ans)
    for i, v in enumerate(tensor_seq):
        ans[i, :length[i], :] = v
    return ans
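
# Example (illustrative): given sentence matrices of shapes (3, H) and (1, H),
# pack_sequence returns a single tensor of shape (2, 3, H), right-padding the
# shorter sequence with `padding_value`.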


class HANCLS(nn.Module):
    def __init__(self, init_embed, num_cls):
        super(HANCLS, self).__init__()

        self.embed = get_embeddings(init_embed)
        self.han = HAN(input_size=300,
                       output_size=num_cls,
                       word_hidden_size=50, word_num_layers=1, word_context_size=100,
                       sent_hidden_size=50, sent_num_layers=1, sent_context_size=100
                       )

    def forward(self, input_sents):
        # input_sents: [B, num_sents, seq_len], dtype long
        B, num_sents, seq_len = input_sents.size()
        input_sents = input_sents.view(-1, seq_len)  # flatten to [B*num_sents, seq_len]
        words_embed = self.embed(input_sents)  # [B*num_sents, seq_len, word_dim]
        words_embed = words_embed.view(B, num_sents, seq_len, -1)  # recover [B, num_sents, seq_len, word_dim]
        out = self.han(words_embed)

        return {C.OUTPUT: out}

    def predict(self, input_sents):
        x = self.forward(input_sents)[C.OUTPUT]
        return {C.OUTPUT: torch.argmax(x, 1)}


class HAN(nn.Module):
    def __init__(self, input_size, output_size,
                 word_hidden_size, word_num_layers, word_context_size,
                 sent_hidden_size, sent_num_layers, sent_context_size):
        super(HAN, self).__init__()

        self.word_layer = AttentionNet(input_size,
                                       word_hidden_size,
                                       word_num_layers,
                                       word_context_size)
        self.sent_layer = AttentionNet(2 * word_hidden_size,
                                       sent_hidden_size,
                                       sent_num_layers,
                                       sent_context_size)
        self.output_layer = nn.Linear(2 * sent_hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, batch_doc):
        # input is a sequence of matrices, one per document
        doc_vec_list = []
        for doc in batch_doc:
            sent_mat = self.word_layer(doc)  # doc: (num_sent, seq_len, word_dim)
            doc_vec_list.append(sent_mat)    # sent_mat: (num_sent, vec_dim)
        doc_vec = self.sent_layer(pack_sequence(doc_vec_list))
        output = self.softmax(self.output_layer(doc_vec))
        return output


class AttentionNet(nn.Module):
    def __init__(self, input_size, gru_hidden_size, gru_num_layers, context_vec_size):
        super(AttentionNet, self).__init__()

        self.input_size = input_size
        self.gru_hidden_size = gru_hidden_size
        self.gru_num_layers = gru_num_layers
        self.context_vec_size = context_vec_size

        # Encoder
        self.gru = nn.GRU(input_size=input_size,
                          hidden_size=gru_hidden_size,
                          num_layers=gru_num_layers,
                          batch_first=True,
                          bidirectional=True)
        # Attention
        self.fc = nn.Linear(2 * gru_hidden_size, context_vec_size)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)
        # context vector
        self.context_vec = nn.Parameter(torch.Tensor(context_vec_size, 1))
        self.context_vec.data.uniform_(-0.1, 0.1)

    def forward(self, inputs):
        # GRU part: inputs is (batch_size, seq_len, word_dim)
        h_t, hidden = self.gru(inputs)  # h_t: (batch_size, seq_len, 2*gru_hidden_size)
        u = self.tanh(self.fc(h_t))     # u: (batch_size, seq_len, context_vec_size)
        # Attention part
        alpha = self.softmax(torch.matmul(u, self.context_vec))  # alpha: (batch_size, seq_len, 1)
        output = torch.bmm(torch.transpose(h_t, 1, 2), alpha)    # (batch_size, 2*gru_hidden_size, 1)
        return torch.squeeze(output, dim=2)  # (batch_size, 2*gru_hidden_size)
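
# Illustrative usage sketch (hypothetical sizes; assumes fastNLP's get_embeddings
# accepts a (vocab_size, dim) tuple, as HANCLS relies on above):
#
#   model = HANCLS(init_embed=(1000, 300), num_cls=5)
#   x = torch.randint(0, 1000, (4, 8, 20))  # [batch=4, num_sents=8, seq_len=20]
#   out = model(x)[C.OUTPUT]                # log-probabilities of shape (4, 5)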

reproduction/text_classification/train_HAN.py (+109, -0)

@@ -0,0 +1,109 @@
# First, add the following paths to the environment; this is currently only open for internal testing, so the paths have to be declared manually.

import os
import sys
sys.path.append('../../')
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

from fastNLP.core.const import Const as C
from fastNLP.core import LRScheduler
import torch.nn as nn
from fastNLP.io.dataset_loader import SSTLoader
from reproduction.text_classification.data.yelpLoader import yelpLoader
from reproduction.text_classification.model.HAN import HANCLS
from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP.core.trainer import Trainer
from torch.optim import SGD
import torch.cuda
from torch.optim.lr_scheduler import CosineAnnealingLR


## hyperparameters

class Config:
    model_dir_or_name = "en-base-uncased"
    embedding_grad = False
    train_epoch = 30
    batch_size = 100
    num_classes = 5
    task = "yelp"
    # datadir = '/remote-home/lyli/fastNLP/yelp_polarity/'
    datadir = '/remote-home/ygwang/yelp_polarity/'
    datafile = {"train": "train.csv", "test": "test.csv"}
    lr = 1e-3

    def __init__(self):
        self.datapath = {k: os.path.join(self.datadir, v)
                         for k, v in self.datafile.items()}


ops = Config()

## 1. Task info: load the DataInfo with the dataloader

datainfo = yelpLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train'])
print(len(datainfo.datasets['train']))
print(len(datainfo.datasets['test']))


# post process
def make_sents(words):
    sents = [words]
    return sents
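
# make_sents wraps the whole document as a single "sentence", so the HAN
# sentence hierarchy degenerates to one sentence per document; a real sentence
# splitter could be plugged in here instead.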


for dataset in datainfo.datasets.values():
    dataset.apply_field(make_sents, field_name='words', new_field_name='input_sents')

datainfo.datasets['train'].set_input('input_sents')
datainfo.datasets['test'].set_input('input_sents')
datainfo.datasets['train'].set_target('target')
datainfo.datasets['test'].set_target('target')

## 2. Or directly reuse a fastNLP model

vocab = datainfo.vocabs['words']
# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
embedding = StaticEmbedding(vocab)

print(len(vocab))
print(len(datainfo.vocabs['target']))

# model = DPCNN(init_embed=embedding, num_cls=ops.num_classes)
model = HANCLS(init_embed=embedding, num_cls=ops.num_classes)

## 3. Declare the loss, metric, and optimizer
loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)
optimizer = SGD([param for param in model.parameters() if param.requires_grad],
                lr=ops.lr, momentum=0.9, weight_decay=0)

callbacks = []
callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

print(device)

for ds in datainfo.datasets.values():
    ds.apply_field(len, C.INPUT, C.INPUT_LEN)
    ds.set_input(C.INPUT, C.INPUT_LEN)
    ds.set_target(C.TARGET)


## 4. Define the train routine
def train(model, datainfo, loss, metrics, optimizer, num_epochs=ops.train_epoch):
    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
                      metrics=[metrics], dev_data=datainfo.datasets['test'], device=device,
                      check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
                      n_epochs=num_epochs)

    print(trainer.train())


if __name__ == "__main__":
    train(model, datainfo, loss, metric, optimizer)
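
# Illustrative invocation (the hard-coded datadir above points at an internal
# machine; adjust it to a local copy of yelp_polarity before running):
#
#   python train_HAN.py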

reproduction/text_classification/train_char_cnn.py (+3, -4)

@@ -7,7 +7,6 @@ import sys
sys.path.append('../..')
from fastNLP.core.const import Const as C
import torch.nn as nn
-from fastNLP.io.dataset_loader import SSTLoader
from data.yelpLoader import yelpLoader
from data.sstLoader import sst2Loader
from data.IMDBLoader import IMDBLoader
@@ -107,9 +106,9 @@ ops=Config


## 1. Task info: load the DataInfo with the dataloader
-dataloader=sst2Loader()
-dataloader=IMDBLoader()
-#dataloader=yelpLoader(fine_grained=True)
+#dataloader=sst2Loader()
+#dataloader=IMDBLoader()
+dataloader=yelpLoader(fine_grained=True)
datainfo=dataloader.process(ops.datapath,char_level_op=True)
char_vocab=ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"]
ops.number_of_characters=len(char_vocab)

