
TC/LSTM

Three models: LSTM, AWD-LSTM, and LSTM with self-attention
tags/v0.4.10
lyhuang18 5 years ago
parent
commit
d05aca6da6
13 changed files with 827 additions and 52 deletions
  1. +80  -0   reproduction/text_classification/data/IMDBLoader.py
  2. +5   -1   reproduction/text_classification/data/MTL16Loader.py
  3. +99  -0   reproduction/text_classification/data/SSTLoader.py
  4. +60  -51  reproduction/text_classification/data/yelpLoader.py
  5. +31  -0   reproduction/text_classification/model/awd_lstm.py
  6. +86  -0   reproduction/text_classification/model/awdlstm_module.py
  7. +30  -0   reproduction/text_classification/model/lstm.py
  8. +35  -0   reproduction/text_classification/model/lstm_self_attention.py
  9. +99  -0   reproduction/text_classification/model/weight_drop.py
  10. BIN      reproduction/text_classification/results_LSTM.xlsx
  11. +102 -0  reproduction/text_classification/train_awdlstm.py
  12. +99  -0  reproduction/text_classification/train_lstm.py
  13. +101 -0  reproduction/text_classification/train_lstm_att.py

+80 -0  reproduction/text_classification/data/IMDBLoader.py

@@ -0,0 +1,80 @@
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataSetLoader, DataInfo
from typing import Union, Dict, List, Iterator
from fastNLP import DataSet
from fastNLP import Instance
from fastNLP import Vocabulary
from fastNLP import Const
# from reproduction.utils import check_dataloader_paths
from functools import partial

class IMDBLoader(DataSetLoader):
"""
读取IMDB数据集,DataSet包含以下fields:

words: list(str), 需要分类的文本
target: str, 文本的标签


"""

def __init__(self):
super(IMDBLoader, self).__init__()

def _load(self, path):
dataset = DataSet()
with open(path, 'r', encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
parts = line.split('\t')
target = parts[0]
words = parts[1].lower().split()
dataset.append(Instance(words=words, target=target))
if len(dataset)==0:
raise RuntimeError(f"{path} has no valid data.")

        return dataset

    def process(self,
                paths: Union[str, Dict[str, str]],
                src_vocab_opt: VocabularyOption = None,
                tgt_vocab_opt: VocabularyOption = None,
                src_embed_opt: EmbeddingOption = None):
# paths = check_dataloader_paths(paths)
datasets = {}
info = DataInfo()
for name, path in paths.items():
dataset = self.load(path)
datasets[name] = dataset

datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False)

src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
src_vocab.from_dataset(datasets['train'], field_name='words')
src_vocab.index_dataset(*datasets.values(), field_name='words')

tgt_vocab = Vocabulary(unknown=None, padding=None) \
if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
tgt_vocab.from_dataset(datasets['train'], field_name='target')
tgt_vocab.index_dataset(*datasets.values(), field_name='target')

info.vocabs = {
"words": src_vocab,
"target": tgt_vocab
}

info.datasets = datasets

if src_embed_opt is not None:
embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab)
info.embeddings['words'] = embed

for name, dataset in info.datasets.items():
dataset.set_input("words")
dataset.set_target("target")

return info
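A minimal usage sketch (hedged: the tab-separated IMDB files are assumptions, matching the paths the train scripts below use):

    from data.IMDBLoader import IMDBLoader

    loader = IMDBLoader()
    # process() takes {name: path}; the "train" set is further split into train/dev (ratio 0.1, no shuffle)
    datainfo = loader.process({"train": "IMDB_data/train.csv", "test": "IMDB_data/test.csv"})
    print(datainfo.vocabs["words"])     # source vocabulary, built on the train split only
    print(datainfo.datasets["dev"][0])  # Instance with indexed "words" and "target"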

+5 -1  reproduction/text_classification/data/MTL16Loader.py

@@ -32,7 +32,7 @@ class MTL16Loader(DataSetLoader):
continue
parts = line.split('\t')
target = parts[0]
words = parts[1].split()
words = parts[1].lower().split()
dataset.append(Instance(words=words, target=target))
if len(dataset)==0:
raise RuntimeError(f"{path} has no valid data.")
@@ -72,4 +72,8 @@ class MTL16Loader(DataSetLoader):
embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab)
info.embeddings['words'] = embed

for name, dataset in info.datasets.items():
dataset.set_input("words")
dataset.set_target("target")

return info

+99 -0  reproduction/text_classification/data/SSTLoader.py

@@ -0,0 +1,99 @@
from typing import Iterable
from nltk import Tree
from fastNLP.io.base_loader import DataInfo, DataSetLoader
from fastNLP.core.vocabulary import VocabularyOption, Vocabulary
from fastNLP import DataSet
from fastNLP import Instance
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader


class SSTLoader(DataSetLoader):
URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
DATA_DIR = 'sst/'

"""
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`

读取SST数据集, DataSet包含fields::

words: list(str) 需要分类的文本
target: str 文本的标签

数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip

:param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False``
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False``
"""

def __init__(self, subtree=False, fine_grained=False):
self.subtree = subtree

tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
'3': 'positive', '4': 'very positive'}
if not fine_grained:
tag_v['0'] = tag_v['1']
tag_v['4'] = tag_v['3']
self.tag_v = tag_v

def _load(self, path):
"""

:param str path: 存储数据的路径
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象
"""
datalist = []
with open(path, 'r', encoding='utf-8') as f:
datas = []
for l in f:
datas.extend([(s, self.tag_v[t])
for s, t in self._get_one(l, self.subtree)])
ds = DataSet()
for words, tag in datas:
ds.append(Instance(words=words, target=tag))
return ds

@staticmethod
def _get_one(data, subtree):
tree = Tree.fromstring(data)
if subtree:
return [(t.leaves(), t.label()) for t in tree.subtrees()]
return [(tree.leaves(), tree.label())]

def process(self,
paths,
train_ds: Iterable[str] = None,
src_vocab_op: VocabularyOption = None,
tgt_vocab_op: VocabularyOption = None,
src_embed_op: EmbeddingOption = None):
input_name, target_name = 'words', 'target'
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
tgt_vocab = Vocabulary(unknown=None, padding=None) \
if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)

info = DataInfo(datasets=self.load(paths))
_train_ds = [info.datasets[name]
for name in train_ds] if train_ds else info.datasets.values()
src_vocab.from_dataset(*_train_ds, field_name=input_name)
tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
src_vocab.index_dataset(
*info.datasets.values(),
field_name=input_name, new_field_name=input_name)
tgt_vocab.index_dataset(
*info.datasets.values(),
field_name=target_name, new_field_name=target_name)
info.vocabs = {
input_name: src_vocab,
target_name: tgt_vocab
}

if src_embed_op is not None:
src_embed_op.vocab = src_vocab
init_emb = EmbedLoader.load_with_vocab(**src_embed_op)
info.embeddings[input_name] = init_emb

for name, dataset in info.datasets.items():
dataset.set_input(input_name)
dataset.set_target(target_name)

return info
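To make the effect of ``subtree=True`` concrete, a hedged sketch of ``_get_one`` on a made-up PTB-style line:

    from nltk import Tree

    line = "(3 (2 A) (3 (3 good) (2 movie)))"
    tree = Tree.fromstring(line)
    # subtree=False: one example per sentence
    print([(tree.leaves(), tree.label())])
    # subtree=True: every subtree becomes an additional training example
    print([(t.leaves(), t.label()) for t in tree.subtrees()])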


+60 -51  reproduction/text_classification/data/yelpLoader.py

@@ -1,68 +1,77 @@
import ast
from fastNLP import DataSet, Instance, Vocabulary
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io import JsonLoader
from fastNLP.io.base_loader import DataInfo
from fastNLP.io.embed_loader import EmbeddingOption
from fastNLP.io.file_reader import _read_json
from typing import Union, Dict
from reproduction.Star_transformer.datasets import EmbedLoader
from reproduction.utils import check_dataloader_paths
from fastNLP.io.base_loader import DataSetLoader, DataInfo
from typing import Union, Dict, List, Iterator
from fastNLP import DataSet
from fastNLP import Instance
from fastNLP import Vocabulary
from fastNLP import Const
# from reproduction.utils import check_dataloader_paths
from functools import partial
import pandas as pd


class yelpLoader(JsonLoader):
class yelpLoader(DataSetLoader):
"""
读取Yelp数据集, DataSet包含fields:
review_id: str, 22 character unique review id
user_id: str, 22 character unique user id
business_id: str, 22 character business id
useful: int, number of useful votes received
funny: int, number of funny votes received
cool: int, number of cool votes received
date: str, date formatted YYYY-MM-DD
读取IMDB数据集,DataSet包含以下fields:

words: list(str), 需要分类的文本
target: str, 文本的标签
数据来源: https://www.yelp.com/dataset/download
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False``


"""
def __init__(self, fine_grained=False):

def __init__(self):
super(yelpLoader, self).__init__()
tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral',
'4.0': 'positive', '5.0': 'very positive'}
if not fine_grained:
tag_v['1.0'] = tag_v['2.0']
tag_v['5.0'] = tag_v['4.0']
self.fine_grained = fine_grained
self.tag_v = tag_v

def _load(self, path):
ds = DataSet()
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
d = ast.literal_eval(d)
d["words"] = d.pop("text").split()
d["target"] = self.tag_v[str(d.pop("stars"))]
ds.append(Instance(**d))
return ds
dataset = DataSet()
data = pd.read_csv(path, header=None, sep=",").values
for line in data:
target = str(line[0])
words = str(line[1]).lower().split()
dataset.append(Instance(words=words, target=target))
if len(dataset)==0:
raise RuntimeError(f"{path} has no valid data.")

def process(self, paths: Union[str, Dict[str, str]], vocab_opt: VocabularyOption = None,
embed_opt: EmbeddingOption = None):
paths = check_dataloader_paths(paths)
        return dataset

    def process(self,
                paths: Union[str, Dict[str, str]],
                src_vocab_opt: VocabularyOption = None,
                tgt_vocab_opt: VocabularyOption = None,
                src_embed_opt: EmbeddingOption = None):
# paths = check_dataloader_paths(paths)
datasets = {}
info = DataInfo()
vocab = Vocabulary(min_freq=2) if vocab_opt is None else Vocabulary(**vocab_opt)
for name, path in paths.items():
dataset = self.load(path)
datasets[name] = dataset
vocab.from_dataset(dataset, field_name="words")
info.vocabs = vocab

datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False)

src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
src_vocab.from_dataset(datasets['train'], field_name='words')
src_vocab.index_dataset(*datasets.values(), field_name='words')

tgt_vocab = Vocabulary(unknown=None, padding=None) \
if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
tgt_vocab.from_dataset(datasets['train'], field_name='target')
tgt_vocab.index_dataset(*datasets.values(), field_name='target')

info.vocabs = {
"words": src_vocab,
"target": tgt_vocab
}

info.datasets = datasets
if embed_opt is not None:
embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab)

if src_embed_opt is not None:
embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab)
info.embeddings['words'] = embed
return info

for name, dataset in info.datasets.items():
dataset.set_input("words")
dataset.set_target("target")

return info
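The rewritten loader now reads a header-less two-column CSV of (label, text) rather than the original Yelp JSON; a hedged sketch of the expected input and call (file names are made up):

    # yelp_train.csv (assumed layout: no header, comma-separated)
    # 1,"i hated this place ..."
    # 2,"great food and friendly service ..."
    from data.yelpLoader import yelpLoader

    datainfo = yelpLoader().process({"train": "yelp_train.csv", "test": "yelp_test.csv"})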

+31 -0  reproduction/text_classification/model/awd_lstm.py

@@ -0,0 +1,31 @@
import torch
import torch.nn as nn
from fastNLP.core.const import Const as C
from .awdlstm_module import LSTM
from fastNLP.modules import encoder
from fastNLP.modules.decoder.mlp import MLP


class AWDLSTMSentiment(nn.Module):
    def __init__(self, init_embed,
                 num_classes,
                 hidden_dim=256,
                 num_layers=1,
                 nfc=128,
                 wdrop=0.5):
        super(AWDLSTMSentiment, self).__init__()
        self.embed = encoder.Embedding(init_embed)
        self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True, wdrop=wdrop)
        self.mlp = MLP(size_layer=[hidden_dim * 2, nfc, num_classes])

    def forward(self, words):
        x_emb = self.embed(words)
        output, _ = self.lstm(x_emb)
        output = self.mlp(output[:, -1, :])
        return {C.OUTPUT: output}

    def predict(self, words):
        output = self(words)
        _, predict = output[C.OUTPUT].max(dim=1)
        return {C.OUTPUT: predict}
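A hedged instantiation sketch, assuming pytorch 0.4 (see the note in train_awdlstm.py) and that ``encoder.Embedding`` accepts a ``(vocab_size, embedding_dim)`` tuple as in contemporary fastNLP; all sizes are made up:

    import torch
    from model.awd_lstm import AWDLSTMSentiment

    model = AWDLSTMSentiment(init_embed=(1000, 50), num_classes=2, hidden_dim=64, wdrop=0.5)
    words = torch.randint(1, 1000, (4, 7))  # [batch, seq_len] word indices
    out = model(words)                      # dict keyed by Const.OUTPUT holding [4, 2] logits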


+86 -0  reproduction/text_classification/model/awdlstm_module.py

@@ -0,0 +1,86 @@
"""
轻量封装的 Pytorch LSTM 模块.
可在 forward 时传入序列的长度, 自动对padding做合适的处理.
"""
__all__ = [
"LSTM"
]

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn

from fastNLP.modules.utils import initial_parameter
from torch import autograd
from .weight_drop import WeightDrop


class LSTM(nn.Module):
"""
别名::class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM`

LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化
为1; 且可以应对DataParallel中LSTM的使用问题。

:param input_size: 输入 `x` 的特征维度
:param hidden_size: 隐状态 `h` 的特征维度.
:param num_layers: rnn的层数. Default: 1
:param dropout: 层间dropout概率. Default: 0
:param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False``
:param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为
:(batch, seq, feature). Default: ``False``
:param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True``
"""
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
bidirectional=False, bias=True, wdrop=0.5):
super(LSTM, self).__init__()
self.batch_first = batch_first
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first,
dropout=dropout, bidirectional=bidirectional)
self.lstm = WeightDrop(self.lstm, ['weight_hh_l0'], dropout=wdrop)
self.init_param()

def init_param(self):
for name, param in self.named_parameters():
if 'bias' in name:
# based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871
param.data.fill_(0)
n = param.size(0)
start, end = n // 4, n // 2
param.data[start:end].fill_(1)
else:
nn.init.xavier_uniform_(param)

def forward(self, x, seq_len=None, h0=None, c0=None):
"""

:param x: [batch, seq_len, input_size] 输入序列
:param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None``
:param h0: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全0向量. Default: ``None``
:param c0: [batch, hidden_size] 初始Cell状态, 若为 ``None`` , 设为全0向量. Default: ``None``
:return (output, ht) 或 output: 若 ``get_hidden=True`` [batch, seq_len, hidden_size*num_direction] 输出序列
和 [batch, hidden_size*num_direction] 最后时刻隐状态.
"""
batch_size, max_len, _ = x.size()
if h0 is not None and c0 is not None:
hx = (h0, c0)
else:
hx = None
if seq_len is not None and not isinstance(x, rnn.PackedSequence):
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
if self.batch_first:
x = x[sort_idx]
else:
x = x[:, sort_idx]
x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first)
output, hx = self.lstm(x, hx) # -> [N,L,C]
output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len)
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
if self.batch_first:
output = output[unsort_idx]
else:
output = output[:, unsort_idx]
else:
output, hx = self.lstm(x, hx)
return output, hx
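The sort/pack/unsort handling around pack_padded_sequence can be seen in isolation in this hedged sketch (random CPU data):

    import torch
    import torch.nn.utils.rnn as rnn

    x = torch.randn(3, 5, 8)        # [batch, max_len, feat]
    seq_len = torch.tensor([2, 5, 3])
    sort_lens, sort_idx = torch.sort(seq_len, descending=True)
    packed = rnn.pack_padded_sequence(x[sort_idx], sort_lens, batch_first=True)
    # ... the wrapped nn.LSTM runs on `packed` here ...
    padded, _ = rnn.pad_packed_sequence(packed, batch_first=True, total_length=5)
    _, unsort_idx = torch.sort(sort_idx)
    restored = padded[unsort_idx]   # rows back in the original batch order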

+30 -0  reproduction/text_classification/model/lstm.py

@@ -0,0 +1,30 @@
import torch
import torch.nn as nn
from fastNLP.core.const import Const as C
from fastNLP.modules.encoder.lstm import LSTM
from fastNLP.modules import encoder
from fastNLP.modules.decoder.mlp import MLP


class BiLSTMSentiment(nn.Module):
    def __init__(self, init_embed,
                 num_classes,
                 hidden_dim=256,
                 num_layers=1,
                 nfc=128):
        super(BiLSTMSentiment, self).__init__()
        self.embed = encoder.Embedding(init_embed)
        self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True)
        self.mlp = MLP(size_layer=[hidden_dim * 2, nfc, num_classes])

    def forward(self, words):
        x_emb = self.embed(words)
        output, _ = self.lstm(x_emb)
        output = self.mlp(output[:, -1, :])
        return {C.OUTPUT: output}

    def predict(self, words):
        output = self(words)
        _, predict = output[C.OUTPUT].max(dim=1)
        return {C.OUTPUT: predict}


+35 -0  reproduction/text_classification/model/lstm_self_attention.py

@@ -0,0 +1,35 @@
import torch
import torch.nn as nn
from fastNLP.core.const import Const as C
from fastNLP.modules.encoder.lstm import LSTM
from fastNLP.modules import encoder
from fastNLP.modules.aggregator.attention import SelfAttention
from fastNLP.modules.decoder.mlp import MLP


class BiLSTM_SELF_ATTENTION(nn.Module):
    def __init__(self, init_embed,
                 num_classes,
                 hidden_dim=256,
                 num_layers=1,
                 attention_unit=256,
                 attention_hops=1,
                 nfc=128):
        super(BiLSTM_SELF_ATTENTION, self).__init__()
        self.embed = encoder.Embedding(init_embed)
        self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True)
        self.attention = SelfAttention(input_size=hidden_dim * 2, attention_unit=attention_unit, attention_hops=attention_hops)
        self.mlp = MLP(size_layer=[hidden_dim * 2 * attention_hops, nfc, num_classes])

    def forward(self, words):
        x_emb = self.embed(words)
        output, _ = self.lstm(x_emb)
        after_attention, penalty = self.attention(output, words)
        after_attention = after_attention.view(after_attention.size(0), -1)
        output = self.mlp(after_attention)
        return {C.OUTPUT: output}

    def predict(self, words):
        output = self(words)
        _, predict = output[C.OUTPUT].max(dim=1)
        return {C.OUTPUT: predict}
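For orientation, a hedged shape walk-through of the forward pass (batch=4, seq_len=7, hidden_dim=256, attention_hops=1 are made-up values):

    # x_emb:            [4, 7, embed_dim]
    # lstm output:      [4, 7, 512]    (bidirectional: 2 * hidden_dim)
    # after_attention:  [4, 1, 512]    (one weighted sum per attention hop)
    # view(batch, -1):  [4, 512]       (hops flattened for the MLP)
    # mlp output:       [4, num_classes]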

+99 -0  reproduction/text_classification/model/weight_drop.py

@@ -0,0 +1,99 @@
import torch
from torch.nn import Parameter
from functools import wraps

class WeightDrop(torch.nn.Module):
def __init__(self, module, weights, dropout=0, variational=False):
super(WeightDrop, self).__init__()
self.module = module
self.weights = weights
self.dropout = dropout
self.variational = variational
self._setup()

def widget_demagnetizer_y2k_edition(*args, **kwargs):
# We need to replace flatten_parameters with a nothing function
# It must be a function rather than a lambda as otherwise pickling explodes
# We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION!
# (╯°□°)╯︵ ┻━┻
return

def _setup(self):
# Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN
if issubclass(type(self.module), torch.nn.RNNBase):
self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition

for name_w in self.weights:
print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
w = getattr(self.module, name_w)
del self.module._parameters[name_w]
self.module.register_parameter(name_w + '_raw', Parameter(w.data))

def _setweights(self):
for name_w in self.weights:
raw_w = getattr(self.module, name_w + '_raw')
w = None
if self.variational:
mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
if raw_w.is_cuda: mask = mask.cuda()
mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
w = mask.expand_as(raw_w) * raw_w
else:
w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
setattr(self.module, name_w, w)

def forward(self, *args):
self._setweights()
return self.module.forward(*args)

if __name__ == '__main__':
import torch
from weight_drop import WeightDrop

# Input is (seq, batch, input)
x = torch.autograd.Variable(torch.randn(2, 1, 10)).cuda()
h0 = None

###

print('Testing WeightDrop')
print('=-=-=-=-=-=-=-=-=-=')

###

print('Testing WeightDrop with Linear')

lin = WeightDrop(torch.nn.Linear(10, 10), ['weight'], dropout=0.9)
lin.cuda()
run1 = [x.sum() for x in lin(x).data]
run2 = [x.sum() for x in lin(x).data]

print('All items should be different')
print('Run 1:', run1)
print('Run 2:', run2)

assert run1[0] != run2[0]
assert run1[1] != run2[1]

print('---')

###

print('Testing WeightDrop with LSTM')

wdrnn = WeightDrop(torch.nn.LSTM(10, 10), ['weight_hh_l0'], dropout=0.9)
wdrnn.cuda()

run1 = [x.sum() for x in wdrnn(x, h0)[0].data]
run2 = [x.sum() for x in wdrnn(x, h0)[0].data]

print('First timesteps should be equal, all others should differ')
print('Run 1:', run1)
print('Run 2:', run2)

# First time step, not influenced by hidden to hidden weights, should be equal
assert run1[0] == run2[0]
# Second step should not
assert run1[1] != run2[1]

print('---')
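WeightDrop appears to be taken from the Salesforce awd-lstm-lm code base; the core trick is DropConnect on the recurrent weights rather than on activations. A hedged CPU-only sketch of that re-parameterization, without the deprecated Variable API:

    import torch
    import torch.nn.functional as F

    # drop entries of the *weight matrix* itself, resampled on every forward
    raw_w = torch.randn(10, 10)
    w_train = F.dropout(raw_w, p=0.5, training=True)   # new mask (and 1/(1-p) scaling) each call
    w_eval = F.dropout(raw_w, p=0.5, training=False)   # identity at eval time
    assert torch.equal(w_eval, raw_w)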

BIN  reproduction/text_classification/results_LSTM.xlsx


+102 -0  reproduction/text_classification/train_awdlstm.py

@@ -0,0 +1,102 @@
# This model needs to run under pytorch==0.4; weight_drop does not support 1.0

# First add the following paths to the environment variables; this is currently only open for internal testing,
# so the paths have to be declared by hand
import os
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'


import torch.nn as nn

from data.SSTLoader import SSTLoader
from data.IMDBLoader import IMDBLoader
from data.yelpLoader import yelpLoader
from fastNLP.modules.encoder.embedding import StaticEmbedding
from model.awd_lstm import AWDLSTMSentiment

from fastNLP.core.const import Const as C
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP import Trainer, Tester
from torch.optim import Adam
from fastNLP.io.model_io import ModelLoader, ModelSaver

import argparse


class Config():
    train_epoch = 10
    lr = 0.001

    num_classes = 2
    hidden_dim = 256
    num_layers = 1
    nfc = 128
    wdrop = 0.5

    task_name = "IMDB"
    datapath = {"train": "IMDB_data/train.csv", "test": "IMDB_data/test.csv"}
    load_model_path = "./result_IMDB/best_BiLSTM_SELF_ATTENTION_acc_2019-07-07-04-16-51"
    save_model_path = "./result_IMDB_test/"


opt = Config()


# load data
dataloaders = {
"IMDB":IMDBLoader(),
"YELP":yelpLoader(),
"SST-5":SSTLoader(subtree=True,fine_grained=True),
"SST-3":SSTLoader(subtree=True,fine_grained=False)
}

if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]:
    raise ValueError("task name must be one of ['IMDB', 'YELP', 'SST-5', 'SST-3']")

dataloader = dataloaders[opt.task_name]
datainfo=dataloader.process(opt.datapath)
# print(datainfo.datasets["train"])
# print(datainfo)


# define model
vocab=datainfo.vocabs['words']
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True)
model=AWDLSTMSentiment(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim, num_layers=opt.num_layers, nfc=opt.nfc, wdrop=opt.wdrop)


# define loss_function and metrics
loss = CrossEntropyLoss()
metrics = AccuracyMetric()
optimizer = Adam([param for param in model.parameters() if param.requires_grad], lr=opt.lr)


def train(datainfo, model, optimizer, loss, metrics, opt):
trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1,
n_epochs=opt.train_epoch, save_path=opt.save_model_path)
trainer.train()


def test(datainfo, metrics, opt):
# load model
model = ModelLoader.load_pytorch_model(opt.load_model_path)
print("model loaded!")

# Tester
tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0)
acc = tester.test()
print("acc=",acc)



parser = argparse.ArgumentParser()
parser.add_argument('--mode', required=True, dest="mode", help="run mode: 'train' or 'test'")


args = parser.parse_args()
if args.mode == 'train':
train(datainfo, model, optimizer, loss, metrics, opt)
elif args.mode == 'test':
test(datainfo, metrics, opt)
else:
    print('unrecognized mode, expected train or test!')
parser.print_help()
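Usage sketch (the same --mode interface applies to train_lstm.py and train_lstm_att.py below; paths come from Config):

    python train_awdlstm.py --mode train   # train and save checkpoints under save_model_path
    python train_awdlstm.py --mode test    # evaluate the checkpoint at load_model_path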

+99 -0  reproduction/text_classification/train_lstm.py

@@ -0,0 +1,99 @@
# First add the following paths to the environment variables; this is currently only open for internal testing,
# so the paths have to be declared by hand
import os
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'


import torch.nn as nn

from data.SSTLoader import SSTLoader
from data.IMDBLoader import IMDBLoader
from data.yelpLoader import yelpLoader
from fastNLP.modules.encoder.embedding import StaticEmbedding
from model.lstm import BiLSTMSentiment

from fastNLP.core.const import Const as C
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP import Trainer, Tester
from torch.optim import Adam
from fastNLP.io.model_io import ModelLoader, ModelSaver

import argparse


class Config():
    train_epoch = 10
    lr = 0.001

    num_classes = 2
    hidden_dim = 256
    num_layers = 1
    nfc = 128

    task_name = "IMDB"
    datapath = {"train": "IMDB_data/train.csv", "test": "IMDB_data/test.csv"}
    load_model_path = "./result_IMDB/best_BiLSTM_SELF_ATTENTION_acc_2019-07-07-04-16-51"
    save_model_path = "./result_IMDB_test/"


opt = Config()


# load data
dataloaders = {
"IMDB":IMDBLoader(),
"YELP":yelpLoader(),
"SST-5":SSTLoader(subtree=True,fine_grained=True),
"SST-3":SSTLoader(subtree=True,fine_grained=False)
}

if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]:
    raise ValueError("task name must be one of ['IMDB', 'YELP', 'SST-5', 'SST-3']")

dataloader = dataloaders[opt.task_name]
datainfo=dataloader.process(opt.datapath)
# print(datainfo.datasets["train"])
# print(datainfo)


# define model
vocab=datainfo.vocabs['words']
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True)
model=BiLSTMSentiment(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim, num_layers=opt.num_layers, nfc=opt.nfc)


# define loss_function and metrics
loss = CrossEntropyLoss()
metrics = AccuracyMetric()
optimizer = Adam([param for param in model.parameters() if param.requires_grad], lr=opt.lr)


def train(datainfo, model, optimizer, loss, metrics, opt):
trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1,
n_epochs=opt.train_epoch, save_path=opt.save_model_path)
trainer.train()


def test(datainfo, metrics, opt):
# load model
model = ModelLoader.load_pytorch_model(opt.load_model_path)
print("model loaded!")

# Tester
tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0)
acc = tester.test()
print("acc=",acc)



parser = argparse.ArgumentParser()
parser.add_argument('--mode', required=True, dest="mode", help="run mode: 'train' or 'test'")


args = parser.parse_args()
if args.mode == 'train':
train(datainfo, model, optimizer, loss, metrics, opt)
elif args.mode == 'test':
test(datainfo, metrics, opt)
else:
    print('unrecognized mode, expected train or test!')
parser.print_help()

+101 -0  reproduction/text_classification/train_lstm_att.py

@@ -0,0 +1,101 @@
# First add the following paths to the environment variables; this is currently only open for internal testing,
# so the paths have to be declared by hand
import os
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'


import torch.nn as nn

from data.SSTLoader import SSTLoader
from data.IMDBLoader import IMDBLoader
from data.yelpLoader import yelpLoader
from fastNLP.modules.encoder.embedding import StaticEmbedding
from model.lstm_self_attention import BiLSTM_SELF_ATTENTION

from fastNLP.core.const import Const as C
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP import Trainer, Tester
from torch.optim import Adam
from fastNLP.io.model_io import ModelLoader, ModelSaver

import argparse


class Config():
    train_epoch = 10
    lr = 0.001

    num_classes = 2
    hidden_dim = 256
    num_layers = 1
    attention_unit = 256
    attention_hops = 1
    nfc = 128

    task_name = "IMDB"
    datapath = {"train": "IMDB_data/train.csv", "test": "IMDB_data/test.csv"}
    load_model_path = "./result_IMDB/best_BiLSTM_SELF_ATTENTION_acc_2019-07-07-04-16-51"
    save_model_path = "./result_IMDB_test/"


opt = Config()


# load data
dataloaders = {
"IMDB":IMDBLoader(),
"YELP":yelpLoader(),
"SST-5":SSTLoader(subtree=True,fine_grained=True),
"SST-3":SSTLoader(subtree=True,fine_grained=False)
}

if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]:
    raise ValueError("task name must be one of ['IMDB', 'YELP', 'SST-5', 'SST-3']")

dataloader = dataloaders[opt.task_name]
datainfo=dataloader.process(opt.datapath)
# print(datainfo.datasets["train"])
# print(datainfo)


# define model
vocab=datainfo.vocabs['words']
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True)
model=BiLSTM_SELF_ATTENTION(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim, num_layers=opt.num_layers, attention_unit=opt.attention_unit, attention_hops=opt.attention_hops, nfc=opt.nfc)


# define loss_function and metrics
loss = CrossEntropyLoss()
metrics = AccuracyMetric()
optimizer = Adam([param for param in model.parameters() if param.requires_grad], lr=opt.lr)


def train(datainfo, model, optimizer, loss, metrics, opt):
trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1,
n_epochs=opt.train_epoch, save_path=opt.save_model_path)
trainer.train()


def test(datainfo, metrics, opt):
# load model
model = ModelLoader.load_pytorch_model(opt.load_model_path)
print("model loaded!")

# Tester
tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0)
acc = tester.test()
print("acc=",acc)



parser = argparse.ArgumentParser()
parser.add_argument('--mode', required=True, dest="mode", help="run mode: 'train' or 'test'")


args = parser.parse_args()
if args.mode == 'train':
train(datainfo, model, optimizer, loss, metrics, opt)
elif args.mode == 'test':
test(datainfo, metrics, opt)
else:
    print('unrecognized mode, expected train or test!')
parser.print_help()
