import torch
from torch import nn

# from fastNLP.modules.encoder.transformer import TransformerEncoder
from reproduction.Chinese_word_segmentation.models.transformer import TransformerEncoder
from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions, seq_len_to_byte_mask


class TransformerCWS(nn.Module):
    def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100,
                 num_bigram_per_char=None, hidden_size=200, embed_drop_p=0.3,
                 num_layers=1, num_heads=8, tag_size=4):
        super().__init__()

        self.embedding = nn.Embedding(vocab_num, embed_dim)
        input_size = embed_dim
        if bigram_vocab_num:
            self.bigram_embedding = nn.Embedding(bigram_vocab_num, bigram_embed_dim)
            input_size += num_bigram_per_char * bigram_embed_dim

        self.drop = nn.Dropout(embed_drop_p, inplace=True)
        self.fc1 = nn.Linear(input_size, hidden_size)

        # Signature of the earlier fastNLP TransformerEncoder, kept for reference:
        # value_size = hidden_size // num_heads
        # self.transformer = TransformerEncoder(num_layers, model_size=hidden_size, inner_size=hidden_size,
        #                                       key_size=value_size,
        #                                       value_size=value_size, num_head=num_heads)
        self.transformer = TransformerEncoder(num_layers=num_layers, model_size=hidden_size,
                                              num_heads=num_heads, hidden_size=hidden_size)
        self.fc2 = nn.Linear(hidden_size, tag_size)

        # BMES tagging scheme: b=begin, m=middle, e=end of a word, s=single-character word.
        allowed_trans = allowed_transitions({0: 'b', 1: 'm', 2: 'e', 3: 's'}, encoding_type='bmes')
        self.crf = ConditionalRandomField(num_tags=tag_size, include_start_end_trans=False,
                                          allowed_transitions=allowed_trans)

    def forward(self, chars, target, seq_lens, bigrams=None):
        masks = seq_len_to_byte_mask(seq_lens)
        x = self.embedding(chars)
        batch_size = x.size(0)
        length = x.size(1)
        if hasattr(self, 'bigram_embedding'):
            bigrams = self.bigram_embedding(bigrams)  # batch_size x seq_lens x per_char x embed_size
            x = torch.cat([x, bigrams.view(batch_size, length, -1)], dim=-1)
        self.drop(x)  # dropout was built with inplace=True, so x is modified in place
        x = self.fc1(x)
        feats = self.transformer(x, masks)
        feats = self.fc2(feats)
        losses = self.crf(feats, target, masks.float())

        pred_dict = {}
        pred_dict['seq_lens'] = seq_lens
        pred_dict['loss'] = torch.mean(losses)

        return pred_dict
    def predict(self, chars, seq_lens, bigrams=None):
        masks = seq_len_to_byte_mask(seq_lens)

        x = self.embedding(chars)
        batch_size = x.size(0)
        length = x.size(1)
        if hasattr(self, 'bigram_embedding'):
            bigrams = self.bigram_embedding(bigrams)  # batch_size x seq_lens x per_char x embed_size
            x = torch.cat([x, bigrams.view(batch_size, length, -1)], dim=-1)
        self.drop(x)
        x = self.fc1(x)
        feats = self.transformer(x, masks)
        feats = self.fc2(feats)

        probs = self.crf.viterbi_decode(feats, masks, get_score=False)

        return {'pred': probs, 'seq_lens': seq_lens}
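

# A minimal usage sketch for decoding (not part of the original file; shapes
# mirror the smoke test at the bottom of this module):
#
#   model = TransformerCWS(10, embed_dim=100, hidden_size=200,
#                          num_layers=1, num_heads=8, tag_size=4)
#   chars = torch.randint(10, size=(4, 7)).long()
#   seq_lens = torch.ones(4).long() * 7
#   out = model.predict(chars, seq_lens)  # {'pred': BMES tag ids, 'seq_lens': seq_lens}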

from reproduction.Chinese_word_segmentation.models.dilated_transformer import TransformerDilateEncoder


class TransformerDilatedCWS(nn.Module):
    def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100,
                 num_bigram_per_char=None, embed_drop_p=0.3, hidden_size=200, kernel_size=3, dilate='none',
                 num_layers=1, num_heads=8, tag_size=4,
                 relative_pos_embed_dim=0):
        super().__init__()

        self.embedding = nn.Embedding(vocab_num, embed_dim)
        input_size = embed_dim
        if bigram_vocab_num:
            self.bigram_embedding = nn.Embedding(bigram_vocab_num, bigram_embed_dim)
            input_size += num_bigram_per_char * bigram_embed_dim

        self.drop = nn.Dropout(embed_drop_p, inplace=True)
        self.fc1 = nn.Linear(input_size, hidden_size)

        # Signature of the earlier fastNLP TransformerEncoder, kept for reference:
        # value_size = hidden_size // num_heads
        # self.transformer = TransformerEncoder(num_layers, model_size=hidden_size, inner_size=hidden_size,
        #                                       key_size=value_size,
        #                                       value_size=value_size, num_head=num_heads)
        self.transformer = TransformerDilateEncoder(num_layers=num_layers, model_size=hidden_size,
                                                    num_heads=num_heads, hidden_size=hidden_size,
                                                    kernel_size=kernel_size, dilate=dilate,
                                                    relative_pos_embed_dim=relative_pos_embed_dim)
        self.fc2 = nn.Linear(hidden_size, tag_size)

        allowed_trans = allowed_transitions({0: 'b', 1: 'm', 2: 'e', 3: 's'}, encoding_type='bmes')
        self.crf = ConditionalRandomField(num_tags=tag_size, include_start_end_trans=False,
                                          allowed_transitions=allowed_trans)

    def forward(self, chars, target, seq_lens, bigrams=None):
        masks = seq_len_to_byte_mask(seq_lens)
        x = self.embedding(chars)
        batch_size = x.size(0)
        length = x.size(1)
        if hasattr(self, 'bigram_embedding'):
            bigrams = self.bigram_embedding(bigrams)  # batch_size x seq_lens x per_char x embed_size
            x = torch.cat([x, bigrams.view(batch_size, length, -1)], dim=-1)
        self.drop(x)
        x = self.fc1(x)
        feats = self.transformer(x, masks)
        feats = self.fc2(feats)
        losses = self.crf(feats, target, masks.float())

        pred_dict = {}
        pred_dict['seq_lens'] = seq_lens
        pred_dict['loss'] = torch.mean(losses)

        return pred_dict

    def predict(self, chars, seq_lens, bigrams=None):
        masks = seq_len_to_byte_mask(seq_lens)

        x = self.embedding(chars)
        batch_size = x.size(0)
        length = x.size(1)
        if hasattr(self, 'bigram_embedding'):
            bigrams = self.bigram_embedding(bigrams)  # batch_size x seq_lens x per_char x embed_size
            x = torch.cat([x, bigrams.view(batch_size, length, -1)], dim=-1)
        self.drop(x)
        x = self.fc1(x)
        feats = self.transformer(x, masks)
        feats = self.fc2(feats)

        probs = self.crf.viterbi_decode(feats, masks, get_score=False)

        return {'pred': probs, 'seq_lens': seq_lens}
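

# Hedged helper, not part of the original file: convert one decoded BMES tag
# sequence into word spans, using the tag mapping defined above
# (0='b', 1='m', 2='e', 3='s').
def bmes_to_spans(tags):
    spans, start = [], None
    for i, t in enumerate(tags):
        if t == 3:                              # 's': single-character word
            spans.append((i, i + 1))
            start = None
        elif t == 0:                            # 'b': a word begins here
            start = i
        elif t == 2 and start is not None:      # 'e': close the open word
            spans.append((start, i + 1))
            start = None
    return spans
# e.g. bmes_to_spans([0, 2, 3, 0, 1, 2]) -> [(0, 2), (2, 3), (3, 6)]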


class NoamOpt(torch.optim.Optimizer):
    "Optim wrapper that implements rate."

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        "Update the wrapped optimizer's learning rate, then take an optimization step."
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        "Noam learning-rate schedule: linear warmup, then decay proportional to step ** -0.5."
        if step is None:
            step = self._step
        return self.factor * \
               (self.model_size ** (-0.5) *
                min(step ** (-0.5), step * self.warmup ** (-1.5)))
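

# Standalone restatement of the schedule above for quick inspection (this helper
# is an illustration, not part of the original file): the learning rate grows
# linearly for `warmup` steps, peaks, then decays as step ** -0.5.
def _noam_rate(step, model_size=10, factor=1, warmup=400):
    return factor * (model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5)))
# e.g. _noam_rate(1) < _noam_rate(400) and _noam_rate(400) > _noam_rate(40000)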


def TransformerCWS_test():
    transformer = TransformerCWS(10, embed_dim=100, bigram_vocab_num=10, bigram_embed_dim=100,
                                 num_bigram_per_char=8, hidden_size=200, embed_drop_p=0.3,
                                 num_layers=1, num_heads=8, tag_size=4)
    chars = torch.randint(10, size=(4, 7)).long()
    bigrams = torch.randint(10, size=(4, 56)).long()  # 7 chars * 8 bigrams per char
    seq_lens = torch.ones(4).long() * 7
    target = torch.randint(4, size=(4, 7))

    print(transformer(chars, target, seq_lens, bigrams))

    optimizer = torch.optim.Adam(transformer.parameters())
    opt = NoamOpt(10, 1, 400, optimizer)
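

# Hedged companion smoke test, not in the original file: the dilated variant is
# assumed to accept the same inputs; kernel_size and dilate keep their defaults here.
def TransformerDilatedCWS_test():
    model = TransformerDilatedCWS(10, embed_dim=100, hidden_size=200,
                                  num_layers=1, num_heads=8, tag_size=4)
    chars = torch.randint(10, size=(4, 7)).long()
    seq_lens = torch.ones(4).long() * 7
    target = torch.randint(4, size=(4, 7))
    print(model(chars, target, seq_lens))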


if __name__ == '__main__':
    TransformerCWS_test()