
Merge pull request #171 from QipengGuo/dev0.5.0

[bugfix] hot fix for Star-Transformer on SST
tags/v0.4.10
Xipeng Qiu committed via GitHub 5 years ago
parent commit ba2732ac73
6 changed files with 64 additions and 36 deletions
  1. fastNLP/io/data_loader/sst.py  (+5 -2)
  2. fastNLP/models/star_transformer.py  (+3 -3)
  3. fastNLP/modules/encoder/star_transformer.py  (+8 -5)
  4. reproduction/Star_transformer/datasets.py  (+8 -3)
  5. reproduction/Star_transformer/run.sh  (+2 -2)
  6. reproduction/Star_transformer/train.py  (+38 -21)

fastNLP/io/data_loader/sst.py  (+5 -2)

@@ -1,11 +1,14 @@
from typing import Iterable
from nltk import Tree
+import spacy
from ..base_loader import DataInfo, DataSetLoader
from ...core.vocabulary import VocabularyOption, Vocabulary
from ...core.dataset import DataSet
from ...core.instance import Instance
from ..embed_loader import EmbeddingOption, EmbedLoader


+spacy.prefer_gpu()
+sptk = spacy.load('en')
class SSTLoader(DataSetLoader):
    URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
@@ -56,8 +59,8 @@ class SSTLoader(DataSetLoader):
    def _get_one(data, subtree):
        tree = Tree.fromstring(data)
        if subtree:
-            return [(t.leaves(), t.label()) for t in tree.subtrees()]
-        return [(tree.leaves(), tree.label())]
+            return [([x.text for x in sptk.tokenizer(' '.join(t.leaves()))], t.label()) for t in tree.subtrees() ]
+        return [([x.text for x in sptk.tokenizer(' '.join(tree.leaves()))], tree.label())]

    def process(self,
                paths,
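Note: the patched _get_one joins each (sub)tree's PTB leaves back into a sentence and re-tokenizes it with spaCy instead of returning the raw leaves. A minimal standalone sketch of that step, assuming spaCy 2.x with the 'en' model installed; the tree string is a made-up SST-style example:

import spacy
from nltk import Tree

sptk = spacy.load('en')  # 'en_core_web_sm' on newer spaCy releases
tree = Tree.fromstring("(4 (2 A) (3 (3 great) (2 movie)))")  # hypothetical SST-style tree
# join the leaves into a sentence and let spaCy's tokenizer re-split it,
# mirroring the new return statements in _get_one
words = [tok.text for tok in sptk.tokenizer(' '.join(tree.leaves()))]
print(words, tree.label())  # ['A', 'great', 'movie'] 4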


fastNLP/models/star_transformer.py  (+3 -3)

@@ -46,7 +46,7 @@ class StarTransEnc(nn.Module):
        super(StarTransEnc, self).__init__()
        self.embedding = get_embeddings(init_embed)
        emb_dim = self.embedding.embedding_dim
-        self.emb_fc = nn.Linear(emb_dim, hidden_size)
+        #self.emb_fc = nn.Linear(emb_dim, hidden_size)
        self.emb_drop = nn.Dropout(emb_dropout)
        self.encoder = StarTransformer(hidden_size=hidden_size,
                                       num_layers=num_layers,
@@ -65,7 +65,7 @@
            [batch, hidden] 全局 relay 节点, 详见论文
        """
        x = self.embedding(x)
-        x = self.emb_fc(self.emb_drop(x))
+        #x = self.emb_fc(self.emb_drop(x))
        nodes, relay = self.encoder(x, mask)
        return nodes, relay
@@ -205,7 +205,7 @@ class STSeqCls(nn.Module):
                             max_len=max_len,
                             emb_dropout=emb_dropout,
                             dropout=dropout)
-        self.cls = _Cls(hidden_size, num_cls, cls_hidden_size)
+        self.cls = _Cls(hidden_size, num_cls, cls_hidden_size, dropout=dropout)

    def forward(self, words, seq_len):
        """


fastNLP/modules/encoder/star_transformer.py  (+8 -5)

@@ -35,11 +35,13 @@ class StarTransformer(nn.Module):
        self.iters = num_layers
        self.norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(self.iters)])
+        self.emb_fc = nn.Conv2d(hidden_size, hidden_size, 1)
+        self.emb_drop = nn.Dropout(dropout)
        self.ring_att = nn.ModuleList(
-            [_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout)
+            [_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0)
             for _ in range(self.iters)])
        self.star_att = nn.ModuleList(
-            [_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout)
+            [_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0)
             for _ in range(self.iters)])

        if max_len is not None:
@@ -66,18 +68,19 @@
        smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1)

        embs = data.permute(0, 2, 1)[:, :, :, None]  # B H L 1
-        if self.pos_emb:
+        if self.pos_emb and False:
            P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \
                .view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None]  # 1 H L 1
            embs = embs + P
+        embs = norm_func(self.emb_drop, embs)

        nodes = embs
        relay = embs.mean(2, keepdim=True)
        ex_mask = mask[:, None, :, None].expand(B, H, L, 1)
        r_embs = embs.view(B, H, 1, L)
        for i in range(self.iters):
            ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2)
-            nodes = nodes + F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax))
+            nodes = F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax))
+            #nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax))
            relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask))

            nodes = nodes.masked_fill_(ex_mask, 0)


reproduction/Star_transformer/datasets.py  (+8 -3)

@@ -51,13 +51,15 @@ def load_sst(path, files):
               for sub in [True, False, False]]
    ds_list = [loader.load(os.path.join(path, fn))
               for fn, loader in zip(files, loaders)]
-    word_v = Vocabulary(min_freq=2)
+    word_v = Vocabulary(min_freq=0)
    tag_v = Vocabulary(unknown=None, padding=None)
    for ds in ds_list:
        ds.apply(lambda x: [w.lower()
                            for w in x['words']], new_field_name='words')
-    ds_list[0].drop(lambda x: len(x['words']) < 3)
+    #ds_list[0].drop(lambda x: len(x['words']) < 3)
    update_v(word_v, ds_list[0], 'words')
+    update_v(word_v, ds_list[1], 'words')
+    update_v(word_v, ds_list[2], 'words')
    ds_list[0].apply(lambda x: tag_v.add_word(
        x['target']), new_field_name=None)
@@ -152,7 +154,10 @@ class EmbedLoader:
        # some words from vocab are missing in pre-trained embedding
        # we normally sample each dimension
        vocab_embed = embedding_matrix[np.where(hit_flags)]
-        sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0),
+        #sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0),
+        #                                   size=(len(vocab) - np.sum(hit_flags), emb_dim))
+        sampled_vectors = np.random.uniform(-0.01, 0.01,
                                            size=(len(vocab) - np.sum(hit_flags), emb_dim))

        embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors
        return embedding_matrix
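Note: the change above replaces the Gaussian fallback (fit to the found vectors) with small uniform noise for words that have no pre-trained embedding. A self-contained sketch of just that step, NumPy only, with made-up shapes and hit flags:

import numpy as np

emb_dim, vocab_size = 300, 10                    # hypothetical sizes
embedding_matrix = np.zeros((vocab_size, emb_dim))
hit_flags = np.zeros(vocab_size, dtype=bool)     # which words had a pre-trained vector
hit_flags[[0, 3, 7]] = True
# rows without a pre-trained vector get uniform noise in [-0.01, 0.01),
# mirroring the patched EmbedLoader fallback
missing = np.where(~hit_flags)[0]
embedding_matrix[missing] = np.random.uniform(-0.01, 0.01, size=(len(missing), emb_dim))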

reproduction/Star_transformer/run.sh  (+2 -2)

@@ -1,5 +1,5 @@
#python -u train.py --task pos --ds conll --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > conll_pos102.log 2>&1 &
#python -u train.py --task pos --ds ctb --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > ctb_pos101.log 2>&1 &
-#python -u train.py --task cls --ds sst --mode train --gpu 2 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.5 --ep 50 --bsz 128 > sst_cls201.log &
+python -u train.py --task cls --ds sst --mode train --gpu 0 --lr 1e-4 --w_decay 5e-5 --lr_decay 1.0 --drop 0.4 --ep 20 --bsz 64 > sst_cls.log &
#python -u train.py --task nli --ds snli --mode train --gpu 1 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 128 > snli_nli201.log &
-python -u train.py --task ner --ds conll --mode train --gpu 0 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 64 > conll_ner201.log &
+#python -u train.py --task ner --ds conll --mode train --gpu 0 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 64 > conll_ner201.log &

reproduction/Star_transformer/train.py  (+38 -21)

@@ -1,4 +1,6 @@
from util import get_argparser, set_gpu, set_rng_seeds, add_model_args
+seed = set_rng_seeds(15360)
+print('RNG SEED {}'.format(seed))
from datasets import load_seqtag, load_sst, load_snli, EmbedLoader, MAX_LEN
import torch.nn as nn
import torch
@@ -7,8 +9,8 @@ import fastNLP as FN
from fastNLP.models.star_transformer import STSeqLabel, STSeqCls, STNLICls
from fastNLP.core.const import Const as C
import sys
-sys.path.append('/remote-home/yfshao/workdir/dev_fastnlp/')
+#sys.path.append('/remote-home/yfshao/workdir/dev_fastnlp/')
+pre_dir = '/home/ec2-user/fast_data/'

g_model_select = {
    'pos': STSeqLabel,
@@ -17,8 +19,8 @@ g_model_select = {
    'nli': STNLICls,
}

-g_emb_file_path = {'en': '/remote-home/yfshao/workdir/datasets/word_vector/glove.840B.300d.txt',
-                   'zh': '/remote-home/yfshao/workdir/datasets/word_vector/cc.zh.300.vec'}
+g_emb_file_path = {'en': pre_dir + 'glove.840B.300d.txt',
+                   'zh': pre_dir + 'cc.zh.300.vec'}

g_args = None
g_model_cfg = None
@@ -53,7 +55,7 @@ def get_conll2012_ner():


def get_sst():
-    path = '/remote-home/yfshao/workdir/datasets/SST'
+    path = pre_dir + 'sst'
    files = ['train.txt', 'dev.txt', 'test.txt']
    return load_sst(path, files)
@@ -94,6 +96,7 @@ class MyCallback(FN.core.callback.Callback):
        nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5.0)

    def on_step_end(self):
+        return
        warm_steps = 6000
        # learning rate warm-up & decay
        if self.step <= warm_steps:
@@ -108,12 +111,11 @@ class MyCallback(FN.core.callback.Callback):


def train():
-    seed = set_rng_seeds(1234)
-    print('RNG SEED {}'.format(seed))
    print('loading data')
    ds_list, word_v, tag_v = g_datasets['{}-{}'.format(
        g_args.ds, g_args.task)]()
    print(ds_list[0][:2])
+    print(len(ds_list[0]), len(ds_list[1]), len(ds_list[2]))
    embed = load_pretrain_emb(word_v, lang='zh' if g_args.ds == 'ctb' else 'en')
    g_model_cfg['num_cls'] = len(tag_v)
    print(g_model_cfg)
@@ -123,11 +125,14 @@ def train():
    def init_model(model):
        for p in model.parameters():
            if p.size(0) != len(word_v):
-                nn.init.normal_(p, 0.0, 0.05)
+                if len(p.size())<2:
+                    nn.init.constant_(p, 0.0)
+                else:
+                    nn.init.normal_(p, 0.0, 0.05)
    init_model(model)
    train_data = ds_list[0]
-    dev_data = ds_list[2]
-    test_data = ds_list[1]
+    dev_data = ds_list[1]
+    test_data = ds_list[2]
    print(tag_v.word2idx)

    if g_args.task in ['pos', 'ner']:
@@ -145,14 +150,26 @@
    }
    metric_key, metric = metrics[g_args.task]
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    ex_param = [x for x in model.parameters(
-    ) if x.requires_grad and x.size(0) != len(word_v)]
-    optim_cfg = [{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1},
-                 {'params': ex_param, 'lr': g_args.lr, 'weight_decay': g_args.w_decay}, ]
-    trainer = FN.Trainer(train_data=train_data, model=model, optimizer=torch.optim.Adam(optim_cfg), loss=loss,
-                         batch_size=g_args.bsz, n_epochs=g_args.ep, print_every=10, dev_data=dev_data, metrics=metric,
-                         metric_key=metric_key, validate_every=3000, save_path=g_args.log, use_tqdm=False,
-                         device=device, callbacks=[MyCallback()])
+    params = [(x,y) for x,y in list(model.named_parameters()) if y.requires_grad and y.size(0) != len(word_v)]
+    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+    print([n for n,p in params])
+    optim_cfg = [
+        #{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1},
+        {'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 'lr': g_args.lr, 'weight_decay': 1.0*g_args.w_decay},
+        {'params': [p for n, p in params if any(nd in n for nd in no_decay)], 'lr': g_args.lr, 'weight_decay': 0.0*g_args.w_decay}
+    ]
+
+    print(model)
+    trainer = FN.Trainer(model=model, train_data=train_data, dev_data=dev_data,
+                         loss=loss, metrics=metric, metric_key=metric_key,
+                         optimizer=torch.optim.Adam(optim_cfg),
+                         n_epochs=g_args.ep, batch_size=g_args.bsz, print_every=100, validate_every=1000,
+                         device=device,
+                         use_tqdm=False, prefetch=False,
+                         save_path=g_args.log,
+                         sampler=FN.BucketSampler(100, g_args.bsz, C.INPUT_LEN),
+                         callbacks=[MyCallback()])

    trainer.train()
    tester = FN.Tester(data=test_data, model=model, metrics=metric,
@@ -195,12 +212,12 @@ def main():
        'init_embed': (None, 300),
        'num_cls': None,
        'hidden_size': g_args.hidden,
-        'num_layers': 4,
+        'num_layers': 2,
        'num_head': g_args.nhead,
        'head_dim': g_args.hdim,
        'max_len': MAX_LEN,
-        'cls_hidden_size': 600,
-        'emb_dropout': 0.3,
+        'cls_hidden_size': 200,
+        'emb_dropout': g_args.drop,
        'dropout': g_args.drop,
    }
    run_select[g_args.mode.lower()]()
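Note: the new optim_cfg above splits the trainable parameters into a weight-decayed group and a no-decay group (biases and LayerNorm parameters). A generic PyTorch sketch of the same grouping, independent of fastNLP; the module and hyperparameter values below are placeholders:

import torch
import torch.nn as nn

model = nn.TransformerEncoderLayer(d_model=300, nhead=6)   # stand-in for the Star-Transformer model
no_decay = ['bias', 'norm1.weight', 'norm2.weight']         # LayerNorm params are named norm1/norm2 in this module
params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
optim_cfg = [
    # decayed group: everything except biases and LayerNorm parameters
    {'params': [p for n, p in params if not any(nd in n for nd in no_decay)],
     'lr': 1e-4, 'weight_decay': 5e-5},
    # no-decay group: biases and LayerNorm parameters
    {'params': [p for n, p in params if any(nd in n for nd in no_decay)],
     'lr': 1e-4, 'weight_decay': 0.0},
]
optimizer = torch.optim.Adam(optim_cfg)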

