From 95aabc941d9283638ce739926f436a3b79c7615b Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Mon, 22 Apr 2019 17:26:55 +0800
Subject: [PATCH] Local stash
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../models/cws_transformer.py | 98 ++++++++++++++++---
 1 file changed, 86 insertions(+), 12 deletions(-)

diff --git a/reproduction/Chinese_word_segmentation/models/cws_transformer.py b/reproduction/Chinese_word_segmentation/models/cws_transformer.py
index 736edade..375eaa14 100644
--- a/reproduction/Chinese_word_segmentation/models/cws_transformer.py
+++ b/reproduction/Chinese_word_segmentation/models/cws_transformer.py
@@ -8,7 +8,8 @@
 from torch import nn
 import torch
 
-from fastNLP.modules.encoder.transformer import TransformerEncoder
+# from fastNLP.modules.encoder.transformer import TransformerEncoder
+from reproduction.Chinese_word_segmentation.models.transformer import TransformerEncoder
 from fastNLP.modules.decoder.CRF import ConditionalRandomField,seq_len_to_byte_mask
 from fastNLP.modules.decoder.CRF import allowed_transitions
 
@@ -27,11 +28,83 @@ class TransformerCWS(nn.Module):
 
         self.fc1 = nn.Linear(input_size, hidden_size)
 
-        value_size = hidden_size//num_heads
-        self.transformer = TransformerEncoder(num_layers, model_size=hidden_size, inner_size=hidden_size,
-                                              key_size=value_size,
-                                              value_size=value_size, num_head=num_heads)
+        # value_size = hidden_size//num_heads
+        # self.transformer = TransformerEncoder(num_layers, model_size=hidden_size, inner_size=hidden_size,
+        #                                       key_size=value_size,
+        #                                       value_size=value_size, num_head=num_heads)
+        self.transformer = TransformerEncoder(num_layers=num_layers, model_size=hidden_size, num_heads=num_heads,
+                                              hidden_size=hidden_size)
+        self.fc2 = nn.Linear(hidden_size, tag_size)
+
+        allowed_trans = allowed_transitions({0:'b', 1:'m', 2:'e', 3:'s'}, encoding_type='bmes')
+        self.crf = ConditionalRandomField(num_tags=tag_size, include_start_end_trans=False,
+                                          allowed_transitions=allowed_trans)
+
+    def forward(self, chars, target, seq_lens, bigrams=None):
+        masks = seq_len_to_byte_mask(seq_lens)
+        x = self.embedding(chars)
+        batch_size = x.size(0)
+        length = x.size(1)
+        if hasattr(self, 'bigram_embedding'):
+            bigrams = self.bigram_embedding(bigrams)  # batch_size x seq_lens x per_char x embed_size
+            x = torch.cat([x, bigrams.view(batch_size, length, -1)], dim=-1)
+        self.drop(x)
+        x = self.fc1(x)
+        feats = self.transformer(x, masks)
+        feats = self.fc2(feats)
+        losses = self.crf(feats, target, masks.float())
+
+        pred_dict = {}
+        pred_dict['seq_lens'] = seq_lens
+        pred_dict['loss'] = torch.mean(losses)
+
+        return pred_dict
+
+    def predict(self, chars, seq_lens, bigrams=None):
+        masks = seq_len_to_byte_mask(seq_lens)
+
+        x = self.embedding(chars)
+        batch_size = x.size(0)
+        length = x.size(1)
+        if hasattr(self, 'bigram_embedding'):
+            bigrams = self.bigram_embedding(bigrams)  # batch_size x seq_lens x per_char x embed_size
+            x = torch.cat([x, bigrams.view(batch_size, length, -1)], dim=-1)
+        self.drop(x)
+        x = self.fc1(x)
+        feats = self.transformer(x, masks)
+        feats = self.fc2(feats)
+
+        probs = self.crf.viterbi_decode(feats, masks, get_score=False)
+
+        return {'pred': probs, 'seq_lens':seq_lens}
+
+
+from reproduction.Chinese_word_segmentation.models.dilated_transformer import TransformerDilateEncoder
+
+class TransformerDilatedCWS(nn.Module):
+    def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
+                 embed_drop_p=0.3, hidden_size=200, kernel_size=3, dilate='none',
+                 num_layers=1, num_heads=8, tag_size=4,
+                 relative_pos_embed_dim=0):
+        super().__init__()
+
+        self.embedding = nn.Embedding(vocab_num, embed_dim)
+        input_size = embed_dim
+        if bigram_vocab_num:
+            self.bigram_embedding = nn.Embedding(bigram_vocab_num, bigram_embed_dim)
+            input_size += num_bigram_per_char*bigram_embed_dim
+
+        self.drop = nn.Dropout(embed_drop_p, inplace=True)
+
+        self.fc1 = nn.Linear(input_size, hidden_size)
+
+        # value_size = hidden_size//num_heads
+        # self.transformer = TransformerEncoder(num_layers, model_size=hidden_size, inner_size=hidden_size,
+        #                                       key_size=value_size,
+        #                                       value_size=value_size, num_head=num_heads)
+        self.transformer = TransformerDilateEncoder(num_layers=num_layers, model_size=hidden_size, num_heads=num_heads,
+                                                    hidden_size=hidden_size, kernel_size=kernel_size, dilate=dilate,
+                                                    relative_pos_embed_dim=relative_pos_embed_dim)
         self.fc2 = nn.Linear(hidden_size, tag_size)
 
         allowed_trans = allowed_transitions({0:'b', 1:'m', 2:'e', 3:'s'}, encoding_type='bmes')
@@ -39,7 +112,7 @@ class TransformerCWS(nn.Module):
                                           allowed_transitions=allowed_trans)
 
     def forward(self, chars, target, seq_lens, bigrams=None):
-        masks = seq_len_to_byte_mask(seq_lens).float()
+        masks = seq_len_to_byte_mask(seq_lens)
         x = self.embedding(chars)
         batch_size = x.size(0)
         length = x.size(1)
@@ -59,7 +132,7 @@ class TransformerCWS(nn.Module):
         return pred_dict
 
     def predict(self, chars, seq_lens, bigrams=None):
-        masks = seq_len_to_byte_mask(seq_lens).float()
+        masks = seq_len_to_byte_mask(seq_lens)
 
         x = self.embedding(chars)
         batch_size = x.size(0)
@@ -77,6 +150,7 @@ class TransformerCWS(nn.Module):
 
         return {'pred': probs, 'seq_lens':seq_lens}
 
+
 class NoamOpt(torch.optim.Optimizer):
     "Optim wrapper that implements rate."
 
@@ -107,10 +181,7 @@ class NoamOpt(torch.optim.Optimizer):
             (self.model_size ** (-0.5) *
             min(step ** (-0.5), step * self.warmup ** (-1.5)))
 
-
-if __name__ == '__main__':
-
-
+def TransformerCWS_test():
     transformer = TransformerCWS(10, embed_dim=100, bigram_vocab_num=10, bigram_embed_dim=100, num_bigram_per_char=8,
                                  hidden_size=200, embed_drop_p=0.3, num_layers=1, num_heads=8, tag_size=4)
     chars = torch.randint(10, size=(4, 7)).long()
@@ -122,4 +193,7 @@ if __name__ == '__main__':
 
     optimizer = torch.optim.Adam(transformer.parameters())
 
-    opt = NoamOpt(10 ,1, 400, optimizer)
\ No newline at end of file
+    opt = NoamOpt(10 ,1, 400, optimizer)
+
+if __name__ == '__main__':
+    TransformerCWS_test()
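-- 
Two reference sketches follow; neither is part of the patch itself.

The patch adds TransformerDilatedCWS but only exercises TransformerCWS in
TransformerCWS_test(). A minimal smoke test for the new class could look like
the sketch below. It assumes the repository root is on PYTHONPATH and that the
dilated_transformer module imported by the patch is available; the tensor
shapes mirror TransformerCWS_test (batch of 4, sequence length 7, 8 bigrams
per character, hence a 4 x 56 bigram tensor), and TransformerDilatedCWS_test
is a hypothetical name, not something the patch defines.

    import torch
    from reproduction.Chinese_word_segmentation.models.cws_transformer import TransformerDilatedCWS

    def TransformerDilatedCWS_test():
        # Same hyper-parameters as TransformerCWS_test; kernel_size=3 and
        # dilate='none' are the constructor defaults.
        model = TransformerDilatedCWS(10, embed_dim=100, bigram_vocab_num=10, bigram_embed_dim=100,
                                      num_bigram_per_char=8, hidden_size=200, embed_drop_p=0.3,
                                      num_layers=1, num_heads=8, tag_size=4)
        chars = torch.randint(10, size=(4, 7)).long()     # 4 sequences of 7 characters
        bigrams = torch.randint(10, size=(4, 56)).long()  # 7 positions * 8 bigrams per char
        seq_lens = torch.ones(4).long() * 7
        target = torch.randint(4, size=(4, 7))

        print(model(chars, target, seq_lens, bigrams)['loss'])  # CRF training loss
        print(model.predict(chars, seq_lens, bigrams)['pred'])  # Viterbi-decoded tags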
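The learning-rate schedule computed by NoamOpt.rate can also be sanity-checked
standalone. The formula below is copied from the patch; noam_rate is a
hypothetical helper, and its default constants assume the NoamOpt(10, 1, 400,
optimizer) call in TransformerCWS_test uses the usual
(model_size, factor, warmup, optimizer) argument order of the Annotated
Transformer implementation this class appears to follow.

    def noam_rate(step, model_size=10, factor=1, warmup=400):
        # factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5):
        # linear warm-up for `warmup` steps, then inverse-square-root decay.
        return factor * (model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5)))

    for step in (1, 100, 400, 1600):
        print(step, round(noam_rate(step), 6))  # rate peaks at step == warmup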