Browse Source

-update DPCNN & train script

-use spacy tokenizer for yelp data
-add set_rng_seed
tags/v0.4.10
yunfan 5 years ago
parent
commit
c5fc29dfef
5 changed files with 94 additions and 49 deletions
  1. +4
    -5
      fastNLP/modules/aggregator/attention.py
  2. +14
    -4
      reproduction/text_classification/data/yelpLoader.py
  3. +14
    -8
      reproduction/text_classification/model/dpcnn.py
  4. +51
    -32
      reproduction/text_classification/train_dpcnn.py
  5. +11
    -0
      reproduction/text_classification/utils/util_init.py

+ 4
- 5
fastNLP/modules/aggregator/attention.py View File

@@ -19,7 +19,7 @@ class DotAttention(nn.Module):
补上文档
"""
def __init__(self, key_size, value_size, dropout=0):
def __init__(self, key_size, value_size, dropout=0.0):
super(DotAttention, self).__init__()
self.key_size = key_size
self.value_size = value_size
@@ -37,7 +37,7 @@ class DotAttention(nn.Module):
"""
output = torch.matmul(Q, K.transpose(1, 2)) / self.scale
if mask_out is not None:
output.masked_fill_(mask_out, -1e8)
output.masked_fill_(mask_out, -1e18)
output = self.softmax(output)
output = self.drop(output)
return torch.matmul(output, V)
@@ -67,9 +67,8 @@ class MultiHeadAttention(nn.Module):
self.k_in = nn.Linear(input_size, in_size)
self.v_in = nn.Linear(input_size, in_size)
# follow the paper, do not apply dropout within dot-product
self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=0)
self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=dropout)
self.out = nn.Linear(value_size * num_head, input_size)
self.drop = TimestepDropout(dropout)
self.reset_parameters()
def reset_parameters(self):
@@ -105,7 +104,7 @@ class MultiHeadAttention(nn.Module):
# concat all heads, do output linear
atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1)
output = self.drop(self.out(atte))
output = self.out(atte)
return output




+ 14
- 4
reproduction/text_classification/data/yelpLoader.py View File

@@ -8,11 +8,20 @@ from fastNLP.io.base_loader import DataInfo
from fastNLP.io.embed_loader import EmbeddingOption
from fastNLP.io.file_reader import _read_json
from typing import Union, Dict
from reproduction.Star_transformer.datasets import EmbedLoader
from reproduction.utils import check_dataloader_paths


def clean_str(sentence, char_lower=False):
def get_tokenizer():
try:
import spacy
en = spacy.load('en')
print('use spacy tokenizer')
return lambda x: [w.text for w in en.tokenizer(x)]
except Exception as e:
print('use raw tokenizer')
return lambda x: x.split()

def clean_str(sentence, tokenizer, char_lower=False):
"""
heavily borrowed from github
https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb
@@ -23,7 +32,7 @@ def clean_str(sentence, char_lower=False):
sentence = sentence.lower()
import re
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
words = sentence.split()
words = tokenizer(sentence)
words_collection = []
for word in words:
if word in ['-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-']:
@@ -65,6 +74,7 @@ class yelpLoader(JsonLoader):
self.fine_grained = fine_grained
self.tag_v = tag_v
self.lower = lower
self.tokenizer = get_tokenizer()

'''
def _load_json(self, path):
@@ -109,7 +119,7 @@ class yelpLoader(JsonLoader):
all_count += 1
if len(row) == 2:
target = self.tag_v[row[0] + ".0"]
words = clean_str(row[1], self.lower)
words = clean_str(row[1], self.tokenizer, self.lower)
if len(words) != 0:
ds.append(Instance(words=words, target=target))
real_count += 1


+ 14
- 8
reproduction/text_classification/model/dpcnn.py View File

@@ -3,22 +3,27 @@ import torch.nn as nn
from fastNLP.modules.utils import get_embeddings
from fastNLP.core import Const as C


class DPCNN(nn.Module):
def __init__(self, init_embed, num_cls, n_filters=256, kernel_size=3, n_layers=7, embed_dropout=0.1, dropout=0.1):
def __init__(self, init_embed, num_cls, n_filters=256,
kernel_size=3, n_layers=7, embed_dropout=0.1, cls_dropout=0.1):
super().__init__()
self.region_embed = RegionEmbedding(init_embed, out_dim=n_filters, kernel_sizes=[3, 5, 9])
self.region_embed = RegionEmbedding(
init_embed, out_dim=n_filters, kernel_sizes=[1, 3, 5])
embed_dim = self.region_embed.embedding_dim
self.conv_list = nn.ModuleList()
for i in range(n_layers):
self.conv_list.append(nn.Sequential(
nn.ReLU(),
nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2),
nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2),
nn.Conv1d(n_filters, n_filters, kernel_size,
padding=kernel_size//2),
nn.Conv1d(n_filters, n_filters, kernel_size,
padding=kernel_size//2),
))
self.pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
self.embed_drop = nn.Dropout(embed_dropout)
self.classfier = nn.Sequential(
nn.Dropout(dropout),
nn.Dropout(cls_dropout),
nn.Linear(n_filters, num_cls),
)
self.reset_parameters()
@@ -57,7 +62,8 @@ class RegionEmbedding(nn.Module):
super().__init__()
if kernel_sizes is None:
kernel_sizes = [5, 9]
assert isinstance(kernel_sizes, list), 'kernel_sizes should be List(int)'
assert isinstance(
kernel_sizes, list), 'kernel_sizes should be List(int)'
self.embed = get_embeddings(init_embed)
try:
embed_dim = self.embed.embedding_dim
@@ -69,14 +75,14 @@ class RegionEmbedding(nn.Module):
nn.Conv1d(embed_dim, embed_dim, ksz, padding=ksz // 2),
))
self.linears = nn.ModuleList([nn.Conv1d(embed_dim, out_dim, 1)
for _ in range(len(kernel_sizes) + 1)])
for _ in range(len(kernel_sizes))])
self.embedding_dim = embed_dim

def forward(self, x):
x = self.embed(x)
x = x.transpose(1, 2)
# B, C, L
out = self.linears[0](x)
out = 0
for conv, fc in zip(self.region_embeds, self.linears[1:]):
conv_i = conv(x)
out = out + fc(conv_i)


+ 51
- 32
reproduction/text_classification/train_dpcnn.py View File

@@ -1,40 +1,44 @@
# 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径

from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.cuda
from fastNLP.core.utils import cache_results
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from fastNLP.core.trainer import Trainer
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding
from reproduction.text_classification.model.dpcnn import DPCNN
from .data.yelpLoader import yelpLoader
from fastNLP.io.dataset_loader import SSTLoader
from data.yelpLoader import yelpLoader
import torch.nn as nn
from fastNLP.core import LRScheduler
from fastNLP.core.const import Const as C
import sys
from fastNLP.core.vocabulary import VocabularyOption
from utils.util_init import set_rng_seeds
import os
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

sys.path.append('../..')


# hyper

class Config():
model_dir_or_name = "en-base-uncased"
embedding_grad = False,
seed = 12345
model_dir_or_name = "dpcnn-yelp-p"
embedding_grad = True
train_epoch = 30
batch_size = 100
num_classes = 2
task = "yelp_p"
#datadir = '/remote-home/yfshao/workdir/datasets/SST'
datadir = '/remote-home/ygwang/yelp_polarity'
datadir = '/remote-home/yfshao/workdir/datasets/yelp_polarity'
#datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"}
datafile = {"train": "train.csv", "test": "test.csv"}
lr = 1e-3
src_vocab_op = VocabularyOption()
embed_dropout = 0.3
cls_dropout = 0.1
weight_decay = 1e-4

def __init__(self):
self.datapath = {k: os.path.join(self.datadir, v)
@@ -43,15 +47,23 @@ class Config():

ops = Config()

set_rng_seeds(ops.seed)
print('RNG SEED: {}'.format(ops.seed))

# 1.task相关信息:利用dataloader载入dataInfo

#datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train'])
datainfo = yelpLoader(fine_grained=True, lower=True).process(
paths=ops.datapath, train_ds=['train'])
print(len(datainfo.datasets['train']))
print(len(datainfo.datasets['test']))

@cache_results(ops.model_dir_or_name+'-data-cache')
def load_data():
datainfo = yelpLoader(fine_grained=True, lower=True).process(
paths=ops.datapath, train_ds=['train'], src_vocab_op=ops.src_vocab_op)
for ds in datainfo.datasets.values():
ds.apply_field(len, C.INPUT, C.INPUT_LEN)
ds.set_input(C.INPUT, C.INPUT_LEN)
ds.set_target(C.TARGET)
return datainfo

datainfo = load_data()

# 2.或直接复用fastNLP的模型

@@ -59,43 +71,50 @@ vocab = datainfo.vocabs['words']
# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
#embedding = StaticEmbedding(vocab)
embedding = StaticEmbedding(
vocab, model_dir_or_name='en-word2vec-300', requires_grad=True)
vocab, model_dir_or_name='en-word2vec-300', requires_grad=ops.embedding_grad,
normalize=False
)

print(len(datainfo.datasets['train']))
print(len(datainfo.datasets['test']))
print(datainfo.datasets['train'][0])

print(len(vocab))
print(len(datainfo.vocabs['target']))

model = DPCNN(init_embed=embedding, num_cls=ops.num_classes)

model = DPCNN(init_embed=embedding, num_cls=ops.num_classes,
embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout)
print(model)

# 3. 声明loss,metric,optimizer
loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)
optimizer = SGD([param for param in model.parameters() if param.requires_grad == True],
lr=ops.lr, momentum=0.9, weight_decay=0)
lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay)

callbacks = []
callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
# callbacks.append
# LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch <
# ops.train_epoch * 0.8 else ops.lr * 0.1))
# )

# callbacks.append(
# FitlogCallback(data=datainfo.datasets, verbose=1)
# )

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

print(device)

for ds in datainfo.datasets.values():
ds.apply_field(len, C.INPUT, C.INPUT_LEN)
ds.set_input(C.INPUT, C.INPUT_LEN)
ds.set_target(C.TARGET)


# 4.定义train方法
def train(model, datainfo, loss, metrics, optimizer, num_epochs=ops.train_epoch):
trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
metrics=[metrics],
dev_data=datainfo.datasets['test'], device=device,
check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
n_epochs=num_epochs)
trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
metrics=[metric],
dev_data=datainfo.datasets['test'], device=device,
check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
n_epochs=ops.train_epoch, num_workers=4)

print(trainer.train())


if __name__ == "__main__":
train(model, datainfo, loss, metric, optimizer)
print(trainer.train())

+ 11
- 0
reproduction/text_classification/utils/util_init.py View File

@@ -0,0 +1,11 @@
import numpy
import torch
import random


def set_rng_seeds(seed):
random.seed(seed)
numpy.random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# print('RNG_SEED {}'.format(seed))

Loading…
Cancel
Save