
update DPCNN & train script

- use spacy tokenizer for yelp data
- add set_rng_seed
tags/v0.4.10
yunfan, 5 years ago
commit c5fc29dfef
5 changed files with 94 additions and 49 deletions

  1. fastNLP/modules/aggregator/attention.py               +4   -5
  2. reproduction/text_classification/data/yelpLoader.py   +14  -4
  3. reproduction/text_classification/model/dpcnn.py       +14  -8
  4. reproduction/text_classification/train_dpcnn.py       +51  -32
  5. reproduction/text_classification/utils/util_init.py   +11  -0

fastNLP/modules/aggregator/attention.py  (+4 -5)

@@ -19,7 +19,7 @@ class DotAttention(nn.Module):
     TODO: fill in the docstring
     """
-    def __init__(self, key_size, value_size, dropout=0):
+    def __init__(self, key_size, value_size, dropout=0.0):
         super(DotAttention, self).__init__()
         self.key_size = key_size
         self.value_size = value_size
@@ -37,7 +37,7 @@ class DotAttention(nn.Module):
         """
         output = torch.matmul(Q, K.transpose(1, 2)) / self.scale
         if mask_out is not None:
-            output.masked_fill_(mask_out, -1e8)
+            output.masked_fill_(mask_out, -1e18)
         output = self.softmax(output)
         output = self.drop(output)
         return torch.matmul(output, V)
@@ -67,9 +67,8 @@ class MultiHeadAttention(nn.Module):
         self.k_in = nn.Linear(input_size, in_size)
         self.v_in = nn.Linear(input_size, in_size)
         # follow the paper, do not apply dropout within dot-product
-        self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=0)
+        self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=dropout)
         self.out = nn.Linear(value_size * num_head, input_size)
-        self.drop = TimestepDropout(dropout)
         self.reset_parameters()

     def reset_parameters(self):
@@ -105,7 +104,7 @@ class MultiHeadAttention(nn.Module):
         # concat all heads, do output linear
         atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1)
-        output = self.drop(self.out(atte))
+        output = self.out(atte)
         return output
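Note: a minimal, self-contained sketch (not the fastNLP module itself) of the masked scaled dot-product step this diff touches, assuming (batch, seq_len, dim) inputs: dropout is applied to the attention weights and masked positions are filled with a large negative value before the softmax. The class and variable names below are illustrative.

import torch
import torch.nn as nn

class TinyDotAttention(nn.Module):
    def __init__(self, key_size, dropout=0.0):
        super().__init__()
        self.scale = key_size ** 0.5
        self.softmax = nn.Softmax(dim=-1)
        self.drop = nn.Dropout(dropout)      # dropout on the attention weights, as in the updated module

    def forward(self, Q, K, V, mask_out=None):
        # Q, K, V: (batch, seq_len, dim); mask_out: bool tensor, True where attention is forbidden
        scores = torch.matmul(Q, K.transpose(1, 2)) / self.scale
        if mask_out is not None:
            scores = scores.masked_fill(mask_out, -1e18)
        weights = self.drop(self.softmax(scores))
        return torch.matmul(weights, V)

q = k = v = torch.randn(2, 5, 8)
print(TinyDotAttention(key_size=8, dropout=0.1)(q, k, v).shape)   # torch.Size([2, 5, 8])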






reproduction/text_classification/data/yelpLoader.py  (+14 -4)

@@ -8,11 +8,20 @@ from fastNLP.io.base_loader import DataInfo
 from fastNLP.io.embed_loader import EmbeddingOption
 from fastNLP.io.file_reader import _read_json
 from typing import Union, Dict
-from reproduction.Star_transformer.datasets import EmbedLoader
 from reproduction.utils import check_dataloader_paths


-def clean_str(sentence, char_lower=False):
+def get_tokenizer():
+    try:
+        import spacy
+        en = spacy.load('en')
+        print('use spacy tokenizer')
+        return lambda x: [w.text for w in en.tokenizer(x)]
+    except Exception as e:
+        print('use raw tokenizer')
+        return lambda x: x.split()
+
+
+def clean_str(sentence, tokenizer, char_lower=False):
     """
     heavily borrowed from github
     https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb
@@ -23,7 +32,7 @@ def clean_str(sentence, char_lower=False):
         sentence = sentence.lower()
     import re
     nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
-    words = sentence.split()
+    words = tokenizer(sentence)
     words_collection = []
     for word in words:
         if word in ['-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-']:
@@ -65,6 +74,7 @@ class yelpLoader(JsonLoader):
         self.fine_grained = fine_grained
         self.tag_v = tag_v
         self.lower = lower
+        self.tokenizer = get_tokenizer()

     '''
     def _load_json(self, path):
@@ -109,7 +119,7 @@ class yelpLoader(JsonLoader):
                 all_count += 1
                 if len(row) == 2:
                     target = self.tag_v[row[0] + ".0"]
-                    words = clean_str(row[1], self.lower)
+                    words = clean_str(row[1], self.tokenizer, self.lower)
                     if len(words) != 0:
                         ds.append(Instance(words=words, target=target))
                         real_count += 1
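Note: a hedged usage sketch of the tokenizer fallback added above, assuming the module can be imported from the path this diff touches; spacy's 'en' model is used when installed, otherwise plain whitespace splitting. The sample sentence and outputs are illustrative.

from reproduction.text_classification.data.yelpLoader import get_tokenizer

tokenize = get_tokenizer()   # prints 'use spacy tokenizer' or 'use raw tokenizer'
print(tokenize("The food wasn't great, honestly."))
# spacy:       ['The', 'food', 'was', "n't", 'great', ',', 'honestly', '.']
# whitespace:  ['The', 'food', "wasn't", 'great,', 'honestly.']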


reproduction/text_classification/model/dpcnn.py  (+14 -8)

@@ -3,22 +3,27 @@ import torch.nn as nn
 from fastNLP.modules.utils import get_embeddings
 from fastNLP.core import Const as C


 class DPCNN(nn.Module):
-    def __init__(self, init_embed, num_cls, n_filters=256, kernel_size=3, n_layers=7, embed_dropout=0.1, dropout=0.1):
+    def __init__(self, init_embed, num_cls, n_filters=256,
+                 kernel_size=3, n_layers=7, embed_dropout=0.1, cls_dropout=0.1):
         super().__init__()
-        self.region_embed = RegionEmbedding(init_embed, out_dim=n_filters, kernel_sizes=[3, 5, 9])
+        self.region_embed = RegionEmbedding(
+            init_embed, out_dim=n_filters, kernel_sizes=[1, 3, 5])
         embed_dim = self.region_embed.embedding_dim
         self.conv_list = nn.ModuleList()
         for i in range(n_layers):
             self.conv_list.append(nn.Sequential(
                 nn.ReLU(),
-                nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2),
-                nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2),
+                nn.Conv1d(n_filters, n_filters, kernel_size,
+                          padding=kernel_size//2),
+                nn.Conv1d(n_filters, n_filters, kernel_size,
+                          padding=kernel_size//2),
             ))
         self.pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
         self.embed_drop = nn.Dropout(embed_dropout)
         self.classfier = nn.Sequential(
-            nn.Dropout(dropout),
+            nn.Dropout(cls_dropout),
             nn.Linear(n_filters, num_cls),
         )
         self.reset_parameters()
@@ -57,7 +62,8 @@ class RegionEmbedding(nn.Module):
         super().__init__()
         if kernel_sizes is None:
             kernel_sizes = [5, 9]
-        assert isinstance(kernel_sizes, list), 'kernel_sizes should be List(int)'
+        assert isinstance(
+            kernel_sizes, list), 'kernel_sizes should be List(int)'
         self.embed = get_embeddings(init_embed)
         try:
             embed_dim = self.embed.embedding_dim
@@ -69,14 +75,14 @@ class RegionEmbedding(nn.Module):
                 nn.Conv1d(embed_dim, embed_dim, ksz, padding=ksz // 2),
             ))
         self.linears = nn.ModuleList([nn.Conv1d(embed_dim, out_dim, 1)
-                                      for _ in range(len(kernel_sizes) + 1)])
+                                      for _ in range(len(kernel_sizes))])
         self.embedding_dim = embed_dim

     def forward(self, x):
         x = self.embed(x)
         x = x.transpose(1, 2)
         # B, C, L
-        out = self.linears[0](x)
+        out = 0
         for conv, fc in zip(self.region_embeds, self.linears[1:]):
             conv_i = conv(x)
             out = out + fc(conv_i)
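Note: for reference, a small self-contained sketch (not the repository class) of the region-embedding idea reworked above: parallel 1-D convolutions with different kernel sizes over the embedded sequence, each projected to out_dim channels and summed. Names and sizes below are illustrative.

import torch
import torch.nn as nn

class TinyRegionEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_dim, out_dim, kernel_sizes=(1, 3, 5)):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # one conv branch per kernel size, each followed by a 1x1 projection to out_dim
        self.branches = nn.ModuleList(
            nn.Conv1d(embed_dim, embed_dim, k, padding=k // 2) for k in kernel_sizes)
        self.projs = nn.ModuleList(
            nn.Conv1d(embed_dim, out_dim, 1) for _ in kernel_sizes)

    def forward(self, x):                   # x: (batch, seq_len) token ids
        h = self.embed(x).transpose(1, 2)   # -> (batch, embed_dim, seq_len)
        out = 0
        for conv, proj in zip(self.branches, self.projs):
            out = out + proj(conv(h))       # sum the projected branch outputs
        return out                          # (batch, out_dim, seq_len)

emb = TinyRegionEmbedding(vocab_size=100, embed_dim=16, out_dim=32)
print(emb(torch.randint(0, 100, (4, 20))).shape)   # torch.Size([4, 32, 20])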


reproduction/text_classification/train_dpcnn.py  (+51 -32)

@@ -1,40 +1,44 @@
 # First, add the following paths to the environment variables; this is currently only open for internal testing, so the paths have to be declared manually

-from torch.optim.lr_scheduler import CosineAnnealingLR
 import torch.cuda
+from fastNLP.core.utils import cache_results
 from torch.optim import SGD
+from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
 from fastNLP.core.trainer import Trainer
 from fastNLP import CrossEntropyLoss, AccuracyMetric
 from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding
 from reproduction.text_classification.model.dpcnn import DPCNN
-from .data.yelpLoader import yelpLoader
-from fastNLP.io.dataset_loader import SSTLoader
+from data.yelpLoader import yelpLoader
 import torch.nn as nn
 from fastNLP.core import LRScheduler
 from fastNLP.core.const import Const as C
+import sys
+from fastNLP.core.vocabulary import VocabularyOption
+from utils.util_init import set_rng_seeds
 import os
 os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
 os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

+sys.path.append('../..')


 # hyper


 class Config():
-    model_dir_or_name = "en-base-uncased"
-    embedding_grad = False,
+    seed = 12345
+    model_dir_or_name = "dpcnn-yelp-p"
+    embedding_grad = True
     train_epoch = 30
     batch_size = 100
     num_classes = 2
     task = "yelp_p"
     #datadir = '/remote-home/yfshao/workdir/datasets/SST'
-    datadir = '/remote-home/ygwang/yelp_polarity'
+    datadir = '/remote-home/yfshao/workdir/datasets/yelp_polarity'
     #datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"}
     datafile = {"train": "train.csv", "test": "test.csv"}
     lr = 1e-3
+    src_vocab_op = VocabularyOption()
+    embed_dropout = 0.3
+    cls_dropout = 0.1
+    weight_decay = 1e-4

     def __init__(self):
         self.datapath = {k: os.path.join(self.datadir, v)
@@ -43,15 +47,23 @@ class Config():

 ops = Config()

+set_rng_seeds(ops.seed)
+print('RNG SEED: {}'.format(ops.seed))

 # 1. Task-related info: load the dataInfo with the dataloader

 #datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train'])
-datainfo = yelpLoader(fine_grained=True, lower=True).process(
-    paths=ops.datapath, train_ds=['train'])
-print(len(datainfo.datasets['train']))
-print(len(datainfo.datasets['test']))
+@cache_results(ops.model_dir_or_name+'-data-cache')
+def load_data():
+    datainfo = yelpLoader(fine_grained=True, lower=True).process(
+        paths=ops.datapath, train_ds=['train'], src_vocab_op=ops.src_vocab_op)
+    for ds in datainfo.datasets.values():
+        ds.apply_field(len, C.INPUT, C.INPUT_LEN)
+        ds.set_input(C.INPUT, C.INPUT_LEN)
+        ds.set_target(C.TARGET)
+    return datainfo
+
+datainfo = load_data()


 # 2. Or directly reuse a fastNLP model


@@ -59,43 +71,50 @@ vocab = datainfo.vocabs['words']
 # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
 #embedding = StaticEmbedding(vocab)
 embedding = StaticEmbedding(
-    vocab, model_dir_or_name='en-word2vec-300', requires_grad=True)
+    vocab, model_dir_or_name='en-word2vec-300', requires_grad=ops.embedding_grad,
+    normalize=False
+)
+
+print(len(datainfo.datasets['train']))
+print(len(datainfo.datasets['test']))
+print(datainfo.datasets['train'][0])

 print(len(vocab))
 print(len(datainfo.vocabs['target']))

-model = DPCNN(init_embed=embedding, num_cls=ops.num_classes)
+model = DPCNN(init_embed=embedding, num_cls=ops.num_classes,
+              embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout)
+print(model)

 # 3. Declare the loss, metric, and optimizer
 loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
 metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)
 optimizer = SGD([param for param in model.parameters() if param.requires_grad == True],
-                lr=ops.lr, momentum=0.9, weight_decay=0)
+                lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay)

 callbacks = []
 callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
+# callbacks.append
+#   LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch <
+#               ops.train_epoch * 0.8 else ops.lr * 0.1))
+# )
+
+# callbacks.append(
+#   FitlogCallback(data=datainfo.datasets, verbose=1)
+# )

 device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

 print(device)

-for ds in datainfo.datasets.values():
-    ds.apply_field(len, C.INPUT, C.INPUT_LEN)
-    ds.set_input(C.INPUT, C.INPUT_LEN)
-    ds.set_target(C.TARGET)

 # 4. Define the training procedure
-def train(model, datainfo, loss, metrics, optimizer, num_epochs=ops.train_epoch):
-    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
-                      metrics=[metrics],
-                      dev_data=datainfo.datasets['test'], device=device,
-                      check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
-                      n_epochs=num_epochs)
-    print(trainer.train())
+trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
+                  metrics=[metric],
+                  dev_data=datainfo.datasets['test'], device=device,
+                  check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
+                  n_epochs=ops.train_epoch, num_workers=4)


 if __name__ == "__main__":
-    train(model, datainfo, loss, metric, optimizer)
+    print(trainer.train())
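Note: a hedged sketch of the @cache_results pattern used in the new load_data() above; as far as I can tell, fastNLP runs the decorated function once, pickles the return value to the given path, and loads that pickle on later calls (there is also a refresh switch to force a rebuild). The cache file name below is illustrative.

from fastNLP.core.utils import cache_results

@cache_results('demo-data-cache.pkl')
def build_dataset():
    print('building...')    # only printed while the cache file does not exist yet
    return list(range(10))

data = build_dataset()      # first call: runs the function and writes demo-data-cache.pkl
data = build_dataset()      # later calls: load the pickled result instead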

reproduction/text_classification/utils/util_init.py  (+11 -0)

@@ -0,0 +1,11 @@
+import numpy
+import torch
+import random
+
+
+def set_rng_seeds(seed):
+    random.seed(seed)
+    numpy.random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    # print('RNG_SEED {}'.format(seed))
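Note: a quick usage sketch of the new helper, assuming the utils directory is importable the same way the train script imports it (utils.util_init); seeding all four RNGs makes runs with the same Config.seed repeat the same random draws.

import torch
from utils.util_init import set_rng_seeds

set_rng_seeds(12345)
a = torch.randn(3)
set_rng_seeds(12345)
b = torch.randn(3)
print(torch.equal(a, b))    # True: the same seed reproduces the same draws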
