@@ -51,7 +51,9 @@ class Trainer(object):
:param Optimizer optimizer: an optimizer object
:param int check_code_level: level of FastNLP code checker. -1: don't check; 0: ignore; 1: warning; 2: strict.\\
`ignore` will not check unused fields; `warning` will warn if some fields are not used; `strict` will
raise an error if some fields are not used. The check works by running the code on a very small batch
(two samples by default) to verify that it runs at all; in theory this process does not modify any parameters.
However, if (1) the model hardcodes batch_size to some fixed value, or (2) the model accumulates a count of
forward passes, a few extra forward computations may occur. In those cases it is recommended to set
check_code_level to -1.
:param str metric_key: a single indicator used to decide the best model based on metric results. It must be one
of the keys returned by the FIRST metric in `metrics`. If the overall result gets better as the indicator gets
smaller, add "-" in front of the string. For example::
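The check_code_level check documented above amounts to a dry run on a tiny batch. Below is a minimal, hypothetical sketch of that idea; the helper name and calling convention are assumptions for illustration, not fastNLP's actual checker.

import torch

def _dry_run_check(model, small_batch):
    # Illustrative only: push ~2 samples through the model once, with no optimizer step,
    # so unused-field and shape errors surface before real training starts.
    model.eval()
    with torch.no_grad():              # the check itself does not modify any parameters
        output = model(**small_batch)
    # As noted above, models that hardcode batch_size or count forward passes can still be
    # affected by these extra calls; setting check_code_level=-1 skips the check entirely.
    return output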
@@ -46,6 +46,21 @@ class DotAtte(nn.Module):
class MultiHeadAtte(nn.Module):
    def __init__(self, input_size, output_size, key_size, value_size, num_atte):
        """
        Implements the following computation:
        QW1: (batch_size, seq_len, input_size) * (input_size, key_size)
        KW2: (batch_size, seq_len, input_size) * (input_size, key_size)
        VW3: (batch_size, seq_len, input_size) * (input_size, value_size)
        softmax(QK^T/sqrt(key_size)) * V: (batch_size, seq_len, value_size); the outputs of the num_atte
        heads are concatenated into (batch_size, seq_len, value_size*num_atte).
        The concatenated result is finally passed through a (value_size*num_atte, output_size) linear layer,
        so the output is (batch_size, seq_len, output_size).

        :param input_size: int, dimension of the input
        :param output_size: int, dimension of the output features
        :param key_size: int, dimension that query and key are projected to
        :param value_size: int, dimension that value is projected to
        :param num_atte: int, number of attention heads
        """
        super(MultiHeadAtte, self).__init__()
        self.in_linear = nn.ModuleList()
        for i in range(num_atte * 3):
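For reference, here is a stand-alone sketch of the shape flow the docstring above describes, written in plain PyTorch. The class and attribute names below are made up for illustration; the actual module builds its per-head projections in the in_linear ModuleList shown in this hunk.

import math
import torch
from torch import nn

class NaiveMultiHeadAtte(nn.Module):
    """Toy sketch of the computation in the docstring above: one Linear per head for Q, K and V."""
    def __init__(self, input_size, output_size, key_size, value_size, num_atte):
        super().__init__()
        self.q_proj = nn.ModuleList([nn.Linear(input_size, key_size) for _ in range(num_atte)])
        self.k_proj = nn.ModuleList([nn.Linear(input_size, key_size) for _ in range(num_atte)])
        self.v_proj = nn.ModuleList([nn.Linear(input_size, value_size) for _ in range(num_atte)])
        self.out = nn.Linear(value_size * num_atte, output_size)
        self.scale = math.sqrt(key_size)

    def forward(self, x):                                    # x: (batch_size, seq_len, input_size)
        heads = []
        for q_l, k_l, v_l in zip(self.q_proj, self.k_proj, self.v_proj):
            q, k, v = q_l(x), k_l(x), v_l(x)                 # (B, L, key_size) / (B, L, value_size)
            atte = torch.softmax(q @ k.transpose(1, 2) / self.scale, dim=-1)  # (B, L, L)
            heads.append(atte @ v)                           # (B, L, value_size)
        concat = torch.cat(heads, dim=-1)                    # (B, L, value_size*num_atte)
        return self.out(concat)                              # (B, L, output_size)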
@@ -0,0 +1,70 @@
""" | |||||
使用transformer作为分词的encoder端 | |||||
""" | |||||
from torch import nn
import torch
from fastNLP.modules.encoder.transformer import TransformerEncoder
from fastNLP.modules.decoder.CRF import ConditionalRandomField, seq_len_to_byte_mask
from fastNLP.modules.decoder.CRF import allowed_transitions


class TransformerCWS(nn.Module):
    def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
                 hidden_size=200, embed_drop_p=0.3, num_layers=1, num_heads=8, tag_size=4):
        super().__init__()

        # character (unigram) embedding; optionally concatenate bigram features per character
        self.embedding = nn.Embedding(vocab_num, embed_dim)
        input_size = embed_dim
        if bigram_vocab_num:
            self.bigram_embedding = nn.Embedding(bigram_vocab_num, bigram_embed_dim)
            input_size += num_bigram_per_char * bigram_embed_dim

        self.drop = nn.Dropout(embed_drop_p, inplace=True)
        self.fc1 = nn.Linear(input_size, hidden_size)

        # the transformer runs on the output of fc1, so its feature size is hidden_size
        value_size = hidden_size // num_heads
        self.transformer = TransformerEncoder(num_layers, input_size=hidden_size, output_size=hidden_size,
                                              key_size=value_size, value_size=value_size, num_atte=num_heads)
        self.fc2 = nn.Linear(hidden_size, tag_size)

        # CRF over BMES tags, with transitions restricted to valid BMES sequences
        allowed_trans = allowed_transitions({0: 'b', 1: 'm', 2: 'e', 3: 's'}, encoding_type='bmes')
        self.crf = ConditionalRandomField(num_tags=tag_size, include_start_end_trans=False,
                                          allowed_transitions=allowed_trans)

    def forward(self, chars, target, seq_lens, bigrams=None):
        masks = seq_len_to_byte_mask(seq_lens)
        x = self.embedding(chars)
        batch_size = x.size(0)
        length = x.size(1)
        if hasattr(self, 'bigram_embedding'):
            bigrams = self.bigram_embedding(bigrams)  # batch_size x seq_len x num_bigram_per_char x embed_dim
            x = torch.cat([x, bigrams.view(batch_size, length, -1)], dim=-1)
        x = self.drop(x)
        x = self.fc1(x)
        feats = self.transformer(x, masks)
        feats = self.fc2(feats)
        losses = self.crf(feats, target, masks.float())

        pred_dict = {}
        pred_dict['seq_lens'] = seq_lens
        pred_dict['loss'] = torch.mean(losses)

        return pred_dict


if __name__ == '__main__':
    transformer = TransformerCWS(10, embed_dim=100, bigram_vocab_num=10, bigram_embed_dim=100, num_bigram_per_char=8,
                                 hidden_size=200, embed_drop_p=0.3, num_layers=1, num_heads=8, tag_size=4)
    chars = torch.randint(10, size=(4, 7)).long()
    bigrams = torch.randint(10, size=(4, 56)).long()
    seq_lens = torch.ones(4).long() * 7
    target = torch.randint(4, size=(4, 7))

    print(transformer(chars, target, seq_lens, bigrams))
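The forward pass above only returns the CRF loss. A prediction path would eventually turn BMES tag ids (the same {0: 'b', 1: 'm', 2: 'e', 3: 's'} map passed to allowed_transitions) back into word spans; a hypothetical post-processing helper, not part of this file, might look like the following.

ID2TAG = {0: 'b', 1: 'm', 2: 'e', 3: 's'}   # same tag map as used for allowed_transitions above

def bmes_to_spans(tag_ids, seq_len):
    """Hypothetical helper: convert one sentence's BMES tag ids into (start, end) word spans."""
    spans, start = [], 0
    for i in range(seq_len):
        tag = ID2TAG[int(tag_ids[i])]
        if tag in ('b', 's'):              # a new word starts at 'b' or at a single-char word 's'
            start = i
        if tag in ('e', 's'):              # a word ends at 'e' or 's'
            spans.append((start, i + 1))   # end index is exclusive
    return spans

# e.g. bmes_to_spans([0, 2, 3, 0, 1, 2], 6) == [(0, 2), (2, 3), (3, 6)]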