
Local stash (本地暂存)

tags/v0.4.10
yh_cc committed 6 years ago (parent commit 95aabc941d)

1 changed file with 86 additions and 12 deletions:

reproduction/Chinese_word_segmentation/models/cws_transformer.py (+86, -12)

@@ -8,7 +8,8 @@


from torch import nn
import torch
-from fastNLP.modules.encoder.transformer import TransformerEncoder
+# from fastNLP.modules.encoder.transformer import TransformerEncoder
+from reproduction.Chinese_word_segmentation.models.transformer import TransformerEncoder
from fastNLP.modules.decoder.CRF import ConditionalRandomField,seq_len_to_byte_mask
from fastNLP.modules.decoder.CRF import allowed_transitions
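
Note: the removed import pulled TransformerEncoder from fastNLP itself; the commit switches to a TransformerEncoder defined inside the reproduction package, whose constructor takes num_layers/model_size/num_heads/hidden_size instead of the key_size/value_size/num_head arguments removed below. The repository's own implementation is not shown in this diff; purely as a mental model, an encoder with that call signature could look like the following sketch (class name and internals are assumptions, not the actual module):

# Hypothetical stand-in, inferred from how self.transformer is constructed and called in this file.
import torch
from torch import nn

class SketchTransformerEncoder(nn.Module):
    def __init__(self, num_layers, model_size, num_heads, hidden_size):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model=model_size, nhead=num_heads,
                                           dim_feedforward=hidden_size, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, x, seq_mask=None):
        # seq_mask: 1 for real tokens, 0 for padding; PyTorch expects True at padded positions.
        padding_mask = None if seq_mask is None else ~seq_mask.bool()
        return self.encoder(x, src_key_padding_mask=padding_mask)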


@@ -27,11 +28,83 @@ class TransformerCWS(nn.Module):


        self.fc1 = nn.Linear(input_size, hidden_size)

-        value_size = hidden_size//num_heads
-        self.transformer = TransformerEncoder(num_layers, model_size=hidden_size, inner_size=hidden_size,
-                                              key_size=value_size,
-                                              value_size=value_size, num_head=num_heads)
+        # value_size = hidden_size//num_heads
+        # self.transformer = TransformerEncoder(num_layers, model_size=hidden_size, inner_size=hidden_size,
+        #                                       key_size=value_size,
+        #                                       value_size=value_size, num_head=num_heads)
+        self.transformer = TransformerEncoder(num_layers=num_layers, model_size=hidden_size, num_heads=num_heads,
+                                              hidden_size=hidden_size)
+        self.fc2 = nn.Linear(hidden_size, tag_size)
+
+        allowed_trans = allowed_transitions({0:'b', 1:'m', 2:'e', 3:'s'}, encoding_type='bmes')
+        self.crf = ConditionalRandomField(num_tags=tag_size, include_start_end_trans=False,
+                                          allowed_transitions=allowed_trans)
+
+    def forward(self, chars, target, seq_lens, bigrams=None):
+        masks = seq_len_to_byte_mask(seq_lens)
+        x = self.embedding(chars)
+        batch_size = x.size(0)
+        length = x.size(1)
+        if hasattr(self, 'bigram_embedding'):
+            bigrams = self.bigram_embedding(bigrams)  # batch_size x seq_lens x per_char x embed_size
+            x = torch.cat([x, bigrams.view(batch_size, length, -1)], dim=-1)
+        self.drop(x)
+        x = self.fc1(x)
+        feats = self.transformer(x, masks)
+        feats = self.fc2(feats)
+        losses = self.crf(feats, target, masks.float())
+
+        pred_dict = {}
+        pred_dict['seq_lens'] = seq_lens
+        pred_dict['loss'] = torch.mean(losses)
+
+        return pred_dict
+
+    def predict(self, chars, seq_lens, bigrams=None):
+        masks = seq_len_to_byte_mask(seq_lens)
+
+        x = self.embedding(chars)
+        batch_size = x.size(0)
+        length = x.size(1)
+        if hasattr(self, 'bigram_embedding'):
+            bigrams = self.bigram_embedding(bigrams)  # batch_size x seq_lens x per_char x embed_size
+            x = torch.cat([x, bigrams.view(batch_size, length, -1)], dim=-1)
+        self.drop(x)
+        x = self.fc1(x)
+        feats = self.transformer(x, masks)
+        feats = self.fc2(feats)
+
+        probs = self.crf.viterbi_decode(feats, masks, get_score=False)
+
+        return {'pred': probs, 'seq_lens':seq_lens}
+
+
+from reproduction.Chinese_word_segmentation.models.dilated_transformer import TransformerDilateEncoder
+
+class TransformerDilatedCWS(nn.Module):
+    def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
+                 embed_drop_p=0.3, hidden_size=200, kernel_size=3, dilate='none',
+                 num_layers=1, num_heads=8, tag_size=4,
+                 relative_pos_embed_dim=0):
+        super().__init__()
+
+        self.embedding = nn.Embedding(vocab_num, embed_dim)
+        input_size = embed_dim
+        if bigram_vocab_num:
+            self.bigram_embedding = nn.Embedding(bigram_vocab_num, bigram_embed_dim)
+            input_size += num_bigram_per_char*bigram_embed_dim
+
+        self.drop = nn.Dropout(embed_drop_p, inplace=True)
+
+        self.fc1 = nn.Linear(input_size, hidden_size)
+
+        # value_size = hidden_size//num_heads
+        # self.transformer = TransformerEncoder(num_layers, model_size=hidden_size, inner_size=hidden_size,
+        #                                       key_size=value_size,
+        #                                       value_size=value_size, num_head=num_heads)
+        self.transformer = TransformerDilateEncoder(num_layers=num_layers, model_size=hidden_size, num_heads=num_heads,
+                                                    hidden_size=hidden_size, kernel_size=kernel_size, dilate=dilate,
+                                                    relative_pos_embed_dim=relative_pos_embed_dim)
        self.fc2 = nn.Linear(hidden_size, tag_size)


        allowed_trans = allowed_transitions({0:'b', 1:'m', 2:'e', 3:'s'}, encoding_type='bmes')
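
Note: the newly added forward/predict pair above trains with the CRF loss (the byte mask is cast to float only for that call) and decodes with Viterbi under the BMES transition constraints set up in __init__. A rough usage sketch, reusing the toy shapes from the test at the bottom of the file; the bigram tensor layout and all values are illustrative assumptions:

import torch

model = TransformerCWS(10, embed_dim=100, bigram_vocab_num=10, bigram_embed_dim=100,
                       num_bigram_per_char=8, hidden_size=200, embed_drop_p=0.3,
                       num_layers=1, num_heads=8, tag_size=4)
chars = torch.randint(10, size=(4, 7)).long()        # 4 sentences, 7 characters each
bigrams = torch.randint(10, size=(4, 7, 8)).long()   # 8 bigram features per character (assumed layout)
seq_lens = torch.tensor([7, 7, 7, 7])
target = torch.randint(4, size=(4, 7)).long()        # gold BMES ids: 0=b, 1=m, 2=e, 3=s

loss = model(chars, target, seq_lens, bigrams=bigrams)['loss']    # CRF training loss, averaged over the batch
tags = model.predict(chars, seq_lens, bigrams=bigrams)['pred']    # Viterbi-decoded tag ids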
@@ -39,7 +112,7 @@ class TransformerCWS(nn.Module):
                                          allowed_transitions=allowed_trans)


    def forward(self, chars, target, seq_lens, bigrams=None):
-        masks = seq_len_to_byte_mask(seq_lens).float()
+        masks = seq_len_to_byte_mask(seq_lens)
        x = self.embedding(chars)
        batch_size = x.size(0)
        length = x.size(1)
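
Note: this hunk and the next make the same change in forward and predict of the dilated model: the sequence-length mask now stays the byte mask returned by seq_len_to_byte_mask instead of being cast to float up front, matching the new TransformerCWS code, which casts only where a float mask is needed (masks.float() for the CRF loss). For reference, a minimal sketch of what such a length-to-mask helper computes; the actual fastNLP helper may differ in detail:

import torch

def length_to_byte_mask(seq_lens):
    # seq_lens: LongTensor of shape (batch,); returns a (batch, max_len) byte mask,
    # 1 at real token positions and 0 at padding.
    batch_size = seq_lens.size(0)
    max_len = int(seq_lens.max())
    positions = torch.arange(max_len).unsqueeze(0).expand(batch_size, max_len)
    return (positions < seq_lens.unsqueeze(1)).byte()

# length_to_byte_mask(torch.tensor([3, 1]))
# -> tensor([[1, 1, 1],
#            [1, 0, 0]], dtype=torch.uint8)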
@@ -59,7 +132,7 @@ class TransformerCWS(nn.Module):
        return pred_dict


    def predict(self, chars, seq_lens, bigrams=None):
-        masks = seq_len_to_byte_mask(seq_lens).float()
+        masks = seq_len_to_byte_mask(seq_lens)


        x = self.embedding(chars)
        batch_size = x.size(0)
@@ -77,6 +150,7 @@ class TransformerCWS(nn.Module):
        return {'pred': probs, 'seq_lens':seq_lens}





class NoamOpt(torch.optim.Optimizer):
    "Optim wrapper that implements rate."


@@ -107,10 +181,7 @@ class NoamOpt(torch.optim.Optimizer):
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))



-if __name__ == '__main__':

+def TransformerCWS_test():
    transformer = TransformerCWS(10, embed_dim=100, bigram_vocab_num=10, bigram_embed_dim=100, num_bigram_per_char=8,
                                 hidden_size=200, embed_drop_p=0.3, num_layers=1, num_heads=8, tag_size=4)
    chars = torch.randint(10, size=(4, 7)).long()
@@ -122,4 +193,7 @@ if __name__ == '__main__':


    optimizer = torch.optim.Adam(transformer.parameters())


-    opt = NoamOpt(10 ,1, 400, optimizer)
+    opt = NoamOpt(10 ,1, 400, optimizer)
+
+if __name__ == '__main__':
+    TransformerCWS_test()
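
Note: TransformerCWS_test keeps the original optimizer setup, wrapping Adam in NoamOpt(10, 1, 400, optimizer). Assuming the conventional (model_size, factor, warmup, optimizer) argument order from the Annotated Transformer recipe, the rate expression shown in the earlier hunk reduces to the sketch below: linear warm-up for 400 steps, then decay proportional to step ** -0.5.

def noam_rate(step, model_size=10, factor=1.0, warmup=400):
    # Learning-rate rule matching the expression in NoamOpt.rate above, with the test's arguments.
    return factor * (model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5)))

# noam_rate(1) ~= 3.95e-05, noam_rate(400) ~= 1.58e-02 (peak), noam_rate(10000) ~= 3.16e-03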
