
fix and improve star_trans on SST

tags/v0.4.10
EC2 Default User, 6 years ago
commit 406e1f7fab
6 changed files with 64 additions and 35 deletions
  1. fastNLP/io/data_loader/sst.py (+5, -2)
  2. fastNLP/models/star_transformer.py (+3, -3)
  3. fastNLP/modules/encoder/star_transformer.py (+8, -5)
  4. reproduction/Star_transformer/datasets.py (+8, -3)
  5. reproduction/Star_transformer/run.sh (+2, -2)
  6. reproduction/Star_transformer/train.py (+38, -20)

fastNLP/io/data_loader/sst.py (+5, -2)

@@ -1,11 +1,14 @@
 from typing import Iterable
 from nltk import Tree
+import spacy
 from ..base_loader import DataInfo, DataSetLoader
 from ...core.vocabulary import VocabularyOption, Vocabulary
 from ...core.dataset import DataSet
 from ...core.instance import Instance
 from ..embed_loader import EmbeddingOption, EmbedLoader


+spacy.prefer_gpu()
+sptk = spacy.load('en')


 class SSTLoader(DataSetLoader):
     URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'

@@ -56,8 +59,8 @@ class SSTLoader(DataSetLoader):
     def _get_one(data, subtree):
         tree = Tree.fromstring(data)
         if subtree:
-            return [(t.leaves(), t.label()) for t in tree.subtrees()]
-        return [(tree.leaves(), tree.label())]
+            return [([x.text for x in sptk.tokenizer(' '.join(t.leaves()))], t.label()) for t in tree.subtrees()]
+        return [([x.text for x in sptk.tokenizer(' '.join(tree.leaves()))], tree.label())]

     def process(self,
                 paths,
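
Note on the loader change: _get_one used to return the raw PTB leaves of each (sub)tree; it now joins the leaves back into a string and re-tokenizes them with the module-level spaCy tokenizer (sptk). A minimal sketch of that path, assuming a spaCy 2.x English model is installed and linked as 'en'; the tree string below is an illustrative SST-style example, not taken from the dataset:

import spacy
from nltk import Tree

spacy.prefer_gpu()                     # no-op when no GPU is available
sptk = spacy.load('en')

data = "(3 (2 An) (3 (3 effective) (2 film)))"   # made-up SST-style tree
tree = Tree.fromstring(data)
# join the PTB leaves and let spaCy's tokenizer split them again
tokens = [x.text for x in sptk.tokenizer(' '.join(tree.leaves()))]
print(tokens, tree.label())            # ['An', 'effective', 'film'] 3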


fastNLP/models/star_transformer.py (+3, -3)

@@ -46,7 +46,7 @@ class StarTransEnc(nn.Module):
         super(StarTransEnc, self).__init__()
         self.embedding = get_embeddings(init_embed)
         emb_dim = self.embedding.embedding_dim
-        self.emb_fc = nn.Linear(emb_dim, hidden_size)
+        #self.emb_fc = nn.Linear(emb_dim, hidden_size)
         self.emb_drop = nn.Dropout(emb_dropout)
         self.encoder = StarTransformer(hidden_size=hidden_size,
                                        num_layers=num_layers,
@@ -65,7 +65,7 @@ class StarTransEnc(nn.Module):
             [batch, hidden] global relay node, see the paper for details
         """
         x = self.embedding(x)
-        x = self.emb_fc(self.emb_drop(x))
+        #x = self.emb_fc(self.emb_drop(x))
         nodes, relay = self.encoder(x, mask)
         return nodes, relay


@@ -205,7 +205,7 @@ class STSeqCls(nn.Module):
                                 max_len=max_len,
                                 emb_dropout=emb_dropout,
                                 dropout=dropout)
-        self.cls = _Cls(hidden_size, num_cls, cls_hidden_size)
+        self.cls = _Cls(hidden_size, num_cls, cls_hidden_size, dropout=dropout)

     def forward(self, words, seq_len):
         """


fastNLP/modules/encoder/star_transformer.py (+8, -5)

@@ -35,11 +35,13 @@ class StarTransformer(nn.Module):
         self.iters = num_layers
         self.norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(self.iters)])
+        self.emb_fc = nn.Conv2d(hidden_size, hidden_size, 1)
+        self.emb_drop = nn.Dropout(dropout)
         self.ring_att = nn.ModuleList(
-            [_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout)
+            [_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0)
             for _ in range(self.iters)])
         self.star_att = nn.ModuleList(
-            [_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout)
+            [_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0)
             for _ in range(self.iters)])
         if max_len is not None:
@@ -66,18 +68,19 @@ class StarTransformer(nn.Module):
         smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1)
         embs = data.permute(0, 2, 1)[:, :, :, None]  # B H L 1
-        if self.pos_emb:
+        if self.pos_emb and False:
             P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \
                 .view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None]  # 1 H L 1
             embs = embs + P
+        embs = norm_func(self.emb_drop, embs)
         nodes = embs
         relay = embs.mean(2, keepdim=True)
         ex_mask = mask[:, None, :, None].expand(B, H, L, 1)
         r_embs = embs.view(B, H, 1, L)
         for i in range(self.iters):
             ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2)
-            nodes = nodes + F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax))
+            nodes = F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax))
+            #nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax))
             relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask))

             nodes = nodes.masked_fill_(ex_mask, 0)
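
Two things happen here: the attention modules now run with dropout=0.0 while a new emb_drop is applied to the embeddings (via norm_func) before the iterations, and a 1x1 Conv2d emb_fc is registered (it is not yet called in this hunk's forward). For reference, a 1x1 Conv2d over the B x H x L x 1 layout used inside StarTransformer is just a per-position linear projection over the channel dimension; a small self-contained check (variable names are illustrative):

import torch
import torch.nn as nn

B, H, L = 2, 16, 7
conv = nn.Conv2d(H, H, kernel_size=1)      # mixes the H channels, shared across all L positions
lin = nn.Linear(H, H)

x = torch.randn(B, H, L, 1)                # B H L 1, as produced by data.permute(0, 2, 1)[:, :, :, None]
y_conv = conv(x)

# the same projection expressed as a Linear over the channel dimension
lin.weight.data = conv.weight.data.view(H, H)
lin.bias.data = conv.bias.data.clone()
y_lin = lin(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

print(torch.allclose(y_conv, y_lin, atol=1e-6))   # True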


reproduction/Star_transformer/datasets.py (+8, -3)

@@ -50,13 +50,15 @@ def load_sst(path, files):
               for sub in [True, False, False]]
    ds_list = [loader.load(os.path.join(path, fn))
               for fn, loader in zip(files, loaders)]
-    word_v = Vocabulary(min_freq=2)
+    word_v = Vocabulary(min_freq=0)
     tag_v = Vocabulary(unknown=None, padding=None)
     for ds in ds_list:
         ds.apply(lambda x: [w.lower()
                             for w in x['words']], new_field_name='words')
-    ds_list[0].drop(lambda x: len(x['words']) < 3)
+    #ds_list[0].drop(lambda x: len(x['words']) < 3)
     update_v(word_v, ds_list[0], 'words')
+    update_v(word_v, ds_list[1], 'words')
+    update_v(word_v, ds_list[2], 'words')
     ds_list[0].apply(lambda x: tag_v.add_word(
         x['target']), new_field_name=None)


@@ -151,7 +153,10 @@ class EmbedLoader:
         # some words from vocab are missing in pre-trained embedding
         # we normally sample each dimension
         vocab_embed = embedding_matrix[np.where(hit_flags)]
-        sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0),
+        #sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0),
+        #                                   size=(len(vocab) - np.sum(hit_flags), emb_dim))
+        sampled_vectors = np.random.uniform(-0.01, 0.01,
                                             size=(len(vocab) - np.sum(hit_flags), emb_dim))

         embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors
         return embedding_matrix
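
The vocabulary is now built over all three splits with min_freq=0, so dev/test words can pick up pre-trained vectors instead of falling back to UNK, and words missing from the embedding file are filled uniformly in [-0.01, 0.01] rather than sampled from a normal fitted to the found vectors. A small numpy sketch of the new back-off, with toy sizes and flags rather than the real loader:

import numpy as np

vocab_size, emb_dim = 5, 4
hit_flags = np.array([True, False, True, False, True])   # which vocab entries were found in the file
embedding_matrix = np.zeros((vocab_size, emb_dim))
embedding_matrix[np.where(hit_flags)] = np.random.randn(int(hit_flags.sum()), emb_dim)

# new behaviour: small uniform noise for the missing rows
sampled_vectors = np.random.uniform(-0.01, 0.01,
                                    size=(vocab_size - int(np.sum(hit_flags)), emb_dim))
embedding_matrix[np.where(~hit_flags)] = sampled_vectors
print(embedding_matrix.round(3))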

reproduction/Star_transformer/run.sh (+2, -2)

@@ -1,5 +1,5 @@
 #python -u train.py --task pos --ds conll --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > conll_pos102.log 2>&1 &
 #python -u train.py --task pos --ds ctb --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > ctb_pos101.log 2>&1 &
-#python -u train.py --task cls --ds sst --mode train --gpu 2 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.5 --ep 50 --bsz 128 > sst_cls201.log &
+python -u train.py --task cls --ds sst --mode train --gpu 0 --lr 1e-4 --w_decay 5e-5 --lr_decay 1.0 --drop 0.4 --ep 20 --bsz 64 > sst_cls.log &
 #python -u train.py --task nli --ds snli --mode train --gpu 1 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 128 > snli_nli201.log &
-python -u train.py --task ner --ds conll --mode train --gpu 0 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 64 > conll_ner201.log &
+#python -u train.py --task ner --ds conll --mode train --gpu 0 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 64 > conll_ner201.log &

reproduction/Star_transformer/train.py (+38, -20)

@@ -7,8 +7,8 @@ import fastNLP as FN
 from fastNLP.models.star_transformer import STSeqLabel, STSeqCls, STNLICls
 from fastNLP.core.const import Const as C
 import sys
-sys.path.append('/remote-home/yfshao/workdir/dev_fastnlp/')
+#sys.path.append('/remote-home/yfshao/workdir/dev_fastnlp/')
+pre_dir = '/home/ec2-user/fast_data/'

 g_model_select = {
     'pos': STSeqLabel,
@@ -17,8 +17,8 @@ g_model_select = {
     'nli': STNLICls,
 }

-g_emb_file_path = {'en': '/remote-home/yfshao/workdir/datasets/word_vector/glove.840B.300d.txt',
-                   'zh': '/remote-home/yfshao/workdir/datasets/word_vector/cc.zh.300.vec'}
+g_emb_file_path = {'en': pre_dir + 'glove.840B.300d.txt',
+                   'zh': pre_dir + 'cc.zh.300.vec'}

 g_args = None
 g_model_cfg = None
@@ -53,7 +53,7 @@ def get_conll2012_ner():


 def get_sst():
-    path = '/remote-home/yfshao/workdir/datasets/SST'
+    path = pre_dir + 'sst'
     files = ['train.txt', 'dev.txt', 'test.txt']
     return load_sst(path, files)

@@ -94,6 +94,7 @@ class MyCallback(FN.core.callback.Callback):
         nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5.0)

     def on_step_end(self):
+        return
         warm_steps = 6000
         # learning rate warm-up & decay
         if self.step <= warm_steps:
@@ -108,12 +109,14 @@


 def train():
-    seed = set_rng_seeds(1234)
+    #seed = set_rng_seeds(1234)
+    seed = set_rng_seeds(np.random.randint(65536))
     print('RNG SEED {}'.format(seed))
     print('loading data')
     ds_list, word_v, tag_v = g_datasets['{}-{}'.format(
         g_args.ds, g_args.task)]()
     print(ds_list[0][:2])
+    print(len(ds_list[0]), len(ds_list[1]), len(ds_list[2]))
     embed = load_pretrain_emb(word_v, lang='zh' if g_args.ds == 'ctb' else 'en')
     g_model_cfg['num_cls'] = len(tag_v)
     print(g_model_cfg)
@@ -123,11 +126,14 @@
     def init_model(model):
         for p in model.parameters():
             if p.size(0) != len(word_v):
-                nn.init.normal_(p, 0.0, 0.05)
+                if len(p.size()) < 2:
+                    nn.init.constant_(p, 0.0)
+                else:
+                    nn.init.normal_(p, 0.0, 0.05)
     init_model(model)
     train_data = ds_list[0]
-    dev_data = ds_list[2]
-    test_data = ds_list[1]
+    dev_data = ds_list[1]
+    test_data = ds_list[2]
     print(tag_v.word2idx)

     if g_args.task in ['pos', 'ner']:
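
The parameter init now distinguishes by shape: 1-D parameters (biases and norm gains) are zeroed, matrices keep the N(0, 0.05) init, and any parameter whose first dimension equals the vocabulary size (the embedding) is left untouched; dev and test also map to ds_list[1] and ds_list[2] again. A standalone sketch of the init rule as it reads in the diff, with a toy usage line:

import torch.nn as nn

def init_model(model, vocab_size):
    for p in model.parameters():
        if p.size(0) != vocab_size:            # skip the embedding matrix
            if len(p.size()) < 2:
                nn.init.constant_(p, 0.0)      # 1-D params (biases, norm gains) -> 0
            else:
                nn.init.normal_(p, 0.0, 0.05)  # weight matrices -> N(0, 0.05)

layer = nn.Linear(8, 3)
init_model(layer, vocab_size=10000)
print(layer.bias)                              # zeros; layer.weight ~ N(0, 0.05)
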
@@ -145,14 +151,26 @@
     }
     metric_key, metric = metrics[g_args.task]
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    ex_param = [x for x in model.parameters(
-    ) if x.requires_grad and x.size(0) != len(word_v)]
-    optim_cfg = [{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1},
-                 {'params': ex_param, 'lr': g_args.lr, 'weight_decay': g_args.w_decay}, ]
-    trainer = FN.Trainer(train_data=train_data, model=model, optimizer=torch.optim.Adam(optim_cfg), loss=loss,
-                         batch_size=g_args.bsz, n_epochs=g_args.ep, print_every=10, dev_data=dev_data, metrics=metric,
-                         metric_key=metric_key, validate_every=3000, save_path=g_args.log, use_tqdm=False,
-                         device=device, callbacks=[MyCallback()])
+
+    params = [(x,y) for x,y in list(model.named_parameters()) if y.requires_grad and y.size(0) != len(word_v)]
+    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+    print([n for n,p in params])
+    optim_cfg = [
+        #{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1},
+        {'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 'lr': g_args.lr, 'weight_decay': 1.0*g_args.w_decay},
+        {'params': [p for n, p in params if any(nd in n for nd in no_decay)], 'lr': g_args.lr, 'weight_decay': 0.0*g_args.w_decay}
+    ]
+
+    print(model)
+    trainer = FN.Trainer(model=model, train_data=train_data, dev_data=dev_data,
+                         loss=loss, metrics=metric, metric_key=metric_key,
+                         optimizer=torch.optim.Adam(optim_cfg),
+                         n_epochs=g_args.ep, batch_size=g_args.bsz, print_every=100, validate_every=1000,
+                         device=device,
+                         use_tqdm=False, prefetch=False,
+                         save_path=g_args.log,
+                         sampler=FN.BucketSampler(100, g_args.bsz, C.INPUT_LEN),
+                         callbacks=[MyCallback()])

     trainer.train()
     tester = FN.Tester(data=test_data, model=model, metrics=metric,
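
The single ex_param group is replaced by the usual no-decay split: parameters whose names contain 'bias', 'LayerNorm.bias' or 'LayerNorm.weight' get weight_decay 0, everything else keeps g_args.w_decay, and the embedding is still excluded by the size(0) != len(word_v) filter. A minimal sketch of the grouping on a toy module whose parameter names make the substring filter visible (the module and values here are illustrative):

import torch
import torch.nn as nn

model = nn.ModuleDict({'proj': nn.Linear(4, 4), 'LayerNorm': nn.LayerNorm(4)})
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]

optim_cfg = [
    # decayed group: proj.weight
    {'params': [p for n, p in params if not any(nd in n for nd in no_decay)],
     'lr': 1e-4, 'weight_decay': 5e-5},
    # no-decay group: proj.bias, LayerNorm.weight, LayerNorm.bias
    {'params': [p for n, p in params if any(nd in n for nd in no_decay)],
     'lr': 1e-4, 'weight_decay': 0.0},
]
optimizer = torch.optim.Adam(optim_cfg)
print([n for n, p in params])   # ['proj.weight', 'proj.bias', 'LayerNorm.weight', 'LayerNorm.bias']
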
@@ -195,12 +213,12 @@
         'init_embed': (None, 300),
         'num_cls': None,
         'hidden_size': g_args.hidden,
-        'num_layers': 4,
+        'num_layers': 2,
         'num_head': g_args.nhead,
         'head_dim': g_args.hdim,
         'max_len': MAX_LEN,
-        'cls_hidden_size': 600,
-        'emb_dropout': 0.3,
+        'cls_hidden_size': 200,
+        'emb_dropout': g_args.drop,
         'dropout': g_args.drop,
     }
     run_select[g_args.mode.lower()]()

