@@ -1,6 +1,6 @@
import torch
from ..modules.decoder.MLP import MLP
from ..modules.decoder.mlp import MLP
class BaseModel(torch.nn.Module):
@@ -6,7 +6,7 @@ import torch.nn as nn
from .base_model import BaseModel
from ..modules import decoder, encoder
from ..modules.decoder.CRF import allowed_transitions
from ..modules.decoder.crf import allowed_transitions
from ..core.utils import seq_len_to_mask
from ..core.const import Const as C
@@ -35,7 +35,7 @@ class SeqLabeling(BaseModel):
        self.Embedding = encoder.embedding.Embedding(init_embed)
        self.Rnn = encoder.lstm.LSTM(self.Embedding.embedding_dim, hidden_size)
        self.Linear = nn.Linear(hidden_size, num_classes)
        self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
        self.Crf = decoder.crf.ConditionalRandomField(num_classes)
        self.mask = None
    def forward(self, words, seq_len, target):
@@ -141,9 +141,9 @@ class AdvSeqLabel(nn.Module):
        self.Linear2 = nn.Linear(hidden_size * 2 // 3, num_classes)
        if id2words is None:
            self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False)
            self.Crf = decoder.crf.ConditionalRandomField(num_classes, include_start_end_trans=False)
        else:
            self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False,
            self.Crf = decoder.crf.ConditionalRandomField(num_classes, include_start_end_trans=False,
                                                          allowed_transitions=allowed_transitions(id2words,
                                                                                                  encoding_type=encoding_type))
@@ -32,19 +32,25 @@ from .encoder import *
from .utils import get_embeddings
__all__ = [
    "LSTM",
    "Embedding",
    # "BertModel",
    "ConvolutionCharEncoder",
    "LSTMCharEncoder",
    "ConvMaxpool",
    "BertModel",
    "Embedding",
    "LSTM",
    "StarTransformer",
    "TransformerEncoder",
    "VarRNN",
    "VarLSTM",
    "VarGRU",
    "MaxPool",
    "MaxPoolWithMask",
    "AvgPool",
    "MultiHeadAttention",
    "BiAttention",
    "MLP",
    "ConditionalRandomField",
    "viterbi_decode",
    "allowed_transitions",
]
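A minimal import sketch (not part of the diff) of what the rename means for callers, assuming an installed fastNLP with the layout shown above: the package-level names are unchanged, only the lowercase file paths are new.

# Package-level imports keep working as before.
from fastNLP.modules import MLP, ConditionalRandomField, viterbi_decode, allowed_transitions
# File-level imports must now use the lowercase module names.
from fastNLP.modules.decoder.crf import ConditionalRandomField
from fastNLP.modules.decoder.mlp import MLP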
@@ -3,12 +3,12 @@ from .pooling import MaxPoolWithMask
from .pooling import AvgPool
from .pooling import AvgPoolWithMask
from .attention import MultiHeadAttention, BiAttention
from .attention import MultiHeadAttention
__all__ = [
    "MaxPool",
    "MaxPoolWithMask",
    "AvgPool",
    "MultiHeadAttention",
    "BiAttention"
]
@@ -1,4 +1,3 @@
__all__ =["MultiHeadAttention"]
import math
import torch
@@ -9,12 +8,17 @@ from ..dropout import TimestepDropout
from ..utils import initial_parameter
__all__ = [
    "MultiHeadAttention"
]
class DotAttention(nn.Module):
    """
    .. todo::
        Fill in the documentation.
    """
    def __init__(self, key_size, value_size, dropout=0):
        super(DotAttention, self).__init__()
        self.key_size = key_size
@@ -22,7 +26,7 @@ class DotAttention(nn.Module):
        self.scale = math.sqrt(key_size)
        self.drop = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=2)
    def forward(self, Q, K, V, mask_out=None):
        """
@@ -41,6 +45,8 @@ class DotAttention(nn.Module):
class MultiHeadAttention(nn.Module):
    """
    Alias: :class:`fastNLP.modules.MultiHeadAttention` :class:`fastNLP.modules.aggregator.attention.MultiHeadAttention`
    :param input_size: int, size of the input dimension; this is also the size of the output dimension.
    :param key_size: int, dimension of each attention head.
@@ -48,13 +54,14 @@ class MultiHeadAttention(nn.Module):
    :param num_head: int, number of heads.
    :param dropout: float.
    """
    def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.input_size = input_size
        self.key_size = key_size
        self.value_size = value_size
        self.num_head = num_head
        in_size = key_size * num_head
        self.q_in = nn.Linear(input_size, in_size)
        self.k_in = nn.Linear(input_size, in_size)
@@ -64,14 +71,14 @@ class MultiHeadAttention(nn.Module):
        self.out = nn.Linear(value_size * num_head, input_size)
        self.drop = TimestepDropout(dropout)
        self.reset_parameters()
    def reset_parameters(self):
        sqrt = math.sqrt
        nn.init.normal_(self.q_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size)))
        nn.init.normal_(self.k_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size)))
        nn.init.normal_(self.v_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.value_size)))
        nn.init.xavier_normal_(self.out.weight)
    def forward(self, Q, K, V, atte_mask_out=None):
        """
@@ -87,7 +94,7 @@ class MultiHeadAttention(nn.Module):
        q = self.q_in(Q).view(batch, sq, n_head, d_k)
        k = self.k_in(K).view(batch, sk, n_head, d_k)
        v = self.v_in(V).view(batch, sk, n_head, d_v)
        # transpose q, k and v to do batch attention
        q = q.permute(2, 0, 1, 3).contiguous().view(-1, sq, d_k)
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, sk, d_k)
@@ -95,7 +102,7 @@ class MultiHeadAttention(nn.Module):
        if atte_mask_out is not None:
            atte_mask_out = atte_mask_out.repeat(n_head, 1, 1)
        atte = self.attention(q, k, v, atte_mask_out).view(n_head, batch, sq, d_v)
        # concat all heads, do output linear
        atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1)
        output = self.drop(self.out(atte))
@@ -104,6 +111,10 @@ class MultiHeadAttention(nn.Module):
class BiAttention(nn.Module):
    r"""Bi Attention module
    .. todo::
        The maintainer of this module should finish polishing it.
    Calculate Bi Attention matrix `e`
    .. math::
@@ -115,11 +126,11 @@ class BiAttention(nn.Module):
        \end{array}
    """
    def __init__(self):
        super(BiAttention, self).__init__()
        self.inf = 10e12
    def forward(self, in_x1, in_x2, x1_len, x2_len):
        """
        :param torch.Tensor in_x1: [batch_size, x1_seq_len, hidden_size] feature representation of the first sentence
@@ -130,36 +141,36 @@ class BiAttention(nn.Module):
            torch.Tensor out_x2: [batch_size, x2_seq_len, hidden_size] feature representation attended to by the first sentence
        """
        assert in_x1.size()[0] == in_x2.size()[0]
        assert in_x1.size()[2] == in_x2.size()[2]
        # The batch size and hidden size must be equal.
        assert in_x1.size()[1] == x1_len.size()[1] and in_x2.size()[1] == x2_len.size()[1]
        # The seq len in in_x and x_len must be equal.
        assert in_x1.size()[0] == x1_len.size()[0] and x1_len.size()[0] == x2_len.size()[0]
        batch_size = in_x1.size()[0]
        x1_max_len = in_x1.size()[1]
        x2_max_len = in_x2.size()[1]
        in_x2_t = torch.transpose(in_x2, 1, 2)  # [batch_size, hidden_size, x2_seq_len]
        attention_matrix = torch.bmm(in_x1, in_x2_t)  # [batch_size, x1_seq_len, x2_seq_len]
        a_mask = x1_len.le(0.5).float() * -self.inf  # [batch_size, x1_seq_len]
        a_mask = a_mask.view(batch_size, x1_max_len, -1)
        a_mask = a_mask.expand(-1, -1, x2_max_len)  # [batch_size, x1_seq_len, x2_seq_len]
        b_mask = x2_len.le(0.5).float() * -self.inf
        b_mask = b_mask.view(batch_size, -1, x2_max_len)
        b_mask = b_mask.expand(-1, x1_max_len, -1)  # [batch_size, x1_seq_len, x2_seq_len]
        attention_a = F.softmax(attention_matrix + a_mask, dim=2)  # [batch_size, x1_seq_len, x2_seq_len]
        attention_b = F.softmax(attention_matrix + b_mask, dim=1)  # [batch_size, x1_seq_len, x2_seq_len]
        out_x1 = torch.bmm(attention_a, in_x2)  # [batch_size, x1_seq_len, hidden_size]
        attention_b_t = torch.transpose(attention_b, 1, 2)
        out_x2 = torch.bmm(attention_b_t, in_x1)  # [batch_size, x2_seq_len, hidden_size]
        return out_x1, out_x2
@@ -173,10 +184,10 @@ class SelfAttention(nn.Module):
    :param float drop: dropout probability, default 0.5
    :param str initial_method: parameter initialization method
    """
    def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5, initial_method=None,):
    def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5, initial_method=None, ):
        super(SelfAttention, self).__init__()
        self.attention_hops = attention_hops
        self.ws1 = nn.Linear(input_size, attention_unit, bias=False)
        self.ws2 = nn.Linear(attention_unit, attention_hops, bias=False)
@@ -185,7 +196,7 @@ class SelfAttention(nn.Module):
        self.drop = nn.Dropout(drop)
        self.tanh = nn.Tanh()
        initial_parameter(self, initial_method)
    def _penalization(self, attention):
        """
        compute the penalization term for attention module
@@ -199,7 +210,7 @@ class SelfAttention(nn.Module):
        mat = torch.bmm(attention, attention_t) - self.I[:attention.size(0)]
        ret = (torch.sum(torch.sum((mat ** 2), 2), 1).squeeze() + 1e-10) ** 0.5
        return torch.sum(ret) / size[0]
    def forward(self, input, input_origin):
        """
        :param torch.Tensor input: [baz, senLen, h_dim] the matrix to attend over
@@ -209,15 +220,14 @@ class SelfAttention(nn.Module):
        """
        input = input.contiguous()
        size = input.size()  # [bsz, len, nhid]
        input_origin = input_origin.expand(self.attention_hops, -1, -1)  # [hops,baz, len]
        input_origin = input_origin.transpose(0, 1).contiguous()  # [baz, hops,len]
        y1 = self.tanh(self.ws1(self.drop(input)))  # [baz,len,dim] -->[bsz,len, attention-unit]
        attention = self.ws2(y1).transpose(1, 2).contiguous()
        # [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len]
        attention = attention + (-999999 * (input_origin == 0).float())  # remove the weight on padding token.
        attention = F.softmax(attention, 2)  # [baz ,hop, len]
        return torch.bmm(attention, input), self._penalization(attention)  # output1 --> [baz ,hop ,nhid]
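For orientation, a minimal usage sketch of MultiHeadAttention as defined above; the sizes are hypothetical and not taken from the diff.

import torch
from fastNLP.modules.aggregator.attention import MultiHeadAttention

# 8 heads of size 32 projected back to the 256-dimensional input space.
attn = MultiHeadAttention(input_size=256, key_size=32, value_size=32, num_head=8)
x = torch.randn(4, 20, 256)   # [batch, seq_len, input_size]
out = attn(x, x, x)           # self-attention; output keeps the shape [4, 20, 256]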
@@ -1,7 +1,7 @@
from .CRF import ConditionalRandomField
from .MLP import MLP
from .crf import ConditionalRandomField
from .mlp import MLP
from .utils import viterbi_decode
from .CRF import allowed_transitions
from .crf import allowed_transitions
__all__ = [
    "MLP",
@@ -3,10 +3,15 @@ from torch import nn
from ..utils import initial_parameter
__all__ = [
    "ConditionalRandomField",
    "allowed_transitions"
]
def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
    """
    Alias: :class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.CRF.allowed_transitions`
    Alias: :class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.crf.allowed_transitions`
    Given a mapping from id to label, return the list of all allowed (from_tag_id, to_tag_id) transitions.
@@ -15,8 +20,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
    :param str encoding_type: supports "bio", "bmes", "bmeso".
    :param bool include_start_end: whether to include transitions from the start tag and to the end tag. For example, in bio, b/o may appear at the beginning of a sequence but i may not;
        if True, the result includes (start_idx, b_idx) and (start_idx, o_idx) but not (start_idx, i_idx);
        start_idx=len(id2label), end_idx=len(id2label)+1.
        If False, the result contains nothing related to start and end.
        start_idx=len(id2label), end_idx=len(id2label)+1. If False, the result contains nothing related to start and end.
    :return: List[Tuple(int, int)]], each inner Tuple is an allowed (from_tag_id, to_tag_id) transition.
    """
    num_tags = len(id2target)
@@ -27,6 +31,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
    id_label_lst = list(id2target.items())
    if include_start_end:
        id_label_lst += [(start_idx, 'start'), (end_idx, 'end')]
    def split_tag_label(from_label):
        from_label = from_label.lower()
        if from_label in ['start', 'end']:
@@ -36,7 +41,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
        from_tag = from_label[:1]
        from_label = from_label[2:]
        return from_tag, from_label
    for from_id, from_label in id_label_lst:
        if from_label in ['<pad>', '<unk>']:
            continue
@@ -60,7 +65,7 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label
    :param str to_label: a label such as "PER" or "LOC"
    :return: bool, whether the transition is allowed
    """
    if to_tag=='start' or from_tag=='end':
    if to_tag == 'start' or from_tag == 'end':
        return False
    encoding_type = encoding_type.lower()
    if encoding_type == 'bio':
@@ -83,12 +88,12 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label
        if from_tag == 'start':
            return to_tag in ('b', 'o')
        elif from_tag in ['b', 'i']:
            return any([to_tag in ['end', 'b', 'o'], to_tag=='i' and from_label==to_label])
            return any([to_tag in ['end', 'b', 'o'], to_tag == 'i' and from_label == to_label])
        elif from_tag == 'o':
            return to_tag in ['end', 'b', 'o']
        else:
            raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag))
    elif encoding_type == 'bmes':
        """
        The first row is to_tag and the first column is from_tag; y means always allowed, - means allowed only when the labels match, n means not allowed
@@ -111,9 +116,9 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label
        if from_tag == 'start':
            return to_tag in ['b', 's']
        elif from_tag == 'b':
            return to_tag in ['m', 'e'] and from_label==to_label
            return to_tag in ['m', 'e'] and from_label == to_label
        elif from_tag == 'm':
            return to_tag in ['m', 'e'] and from_label==to_label
            return to_tag in ['m', 'e'] and from_label == to_label
        elif from_tag in ['e', 's']:
            return to_tag in ['b', 's', 'end']
        else:
@@ -122,21 +127,21 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label
        if from_tag == 'start':
            return to_tag in ['b', 's', 'o']
        elif from_tag == 'b':
            return to_tag in ['m', 'e'] and from_label==to_label
            return to_tag in ['m', 'e'] and from_label == to_label
        elif from_tag == 'm':
            return to_tag in ['m', 'e'] and from_label==to_label
            return to_tag in ['m', 'e'] and from_label == to_label
        elif from_tag in ['e', 's', 'o']:
            return to_tag in ['b', 's', 'end', 'o']
        else:
            raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag))
    else:
        raise ValueError("Only support BIO, BMES, BMESO encoding type, got {}.".format(encoding_type))
class ConditionalRandomField(nn.Module):
    """
    Alias: :class:`fastNLP.modules.ConditionalRandomField` :class:`fastNLP.modules.decoder.CRF.ConditionalRandomField`
    Alias: :class:`fastNLP.modules.ConditionalRandomField` :class:`fastNLP.modules.decoder.crf.ConditionalRandomField`
    Conditional random field.
    Provides two methods, forward() and viterbi_decode(), used for training and inference respectively.
@@ -148,30 +153,31 @@ class ConditionalRandomField(nn.Module):
        obtained from the allowed_transitions() function; if None, all transitions are legal
    :param str initial_method: initialization method. See initial_parameter
    """
    def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None,
                 initial_method=None):
        super(ConditionalRandomField, self).__init__()
        self.include_start_end_trans = include_start_end_trans
        self.num_tags = num_tags
        # the meaning of entry in this matrix is (from_tag_id, to_tag_id) score
        self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags))
        if self.include_start_end_trans:
            self.start_scores = nn.Parameter(torch.randn(num_tags))
            self.end_scores = nn.Parameter(torch.randn(num_tags))
        if allowed_transitions is None:
            constrain = torch.zeros(num_tags + 2, num_tags + 2)
        else:
            constrain = torch.full((num_tags+2, num_tags+2), fill_value=-10000.0, dtype=torch.float)
            constrain = torch.full((num_tags + 2, num_tags + 2), fill_value=-10000.0, dtype=torch.float)
            for from_tag_id, to_tag_id in allowed_transitions:
                constrain[from_tag_id, to_tag_id] = 0
        self._constrain = nn.Parameter(constrain, requires_grad=False)
        initial_parameter(self, initial_method)
    def _normalizer_likelihood(self, logits, mask):
        """Computes the (batch_size,) denominator term for the log-likelihood, which is the
        sum of the likelihoods across all possible state sequences.
@@ -184,21 +190,21 @@ class ConditionalRandomField(nn.Module):
        alpha = logits[0]
        if self.include_start_end_trans:
            alpha = alpha + self.start_scores.view(1, -1)
        flip_mask = mask.eq(0)
        for i in range(1, seq_len):
            emit_score = logits[i].view(batch_size, 1, n_tags)
            trans_score = self.trans_m.view(1, n_tags, n_tags)
            tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
            alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
                alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0)
        if self.include_start_end_trans:
            alpha = alpha + self.end_scores.view(1, -1)
        return torch.logsumexp(alpha, 1)
    def _gold_score(self, logits, tags, mask):
        """
        Compute the score for the gold path.
@@ -210,15 +216,15 @@ class ConditionalRandomField(nn.Module):
        seq_len, batch_size, _ = logits.size()
        batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
        seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
        # trans_socre [L-1, B]
        mask = mask.byte()
        flip_mask = mask.eq(0)
        trans_score = self.trans_m[tags[:seq_len-1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
        trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
        # emit_score [L, B]
        emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags].masked_fill(flip_mask, 0)
        emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0)
        # score [L-1, B]
        score = trans_score + emit_score[:seq_len-1, :]
        score = trans_score + emit_score[:seq_len - 1, :]
        score = score.sum(0) + emit_score[-1].masked_fill(flip_mask[-1], 0)
        if self.include_start_end_trans:
            st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
@@ -227,24 +233,24 @@ class ConditionalRandomField(nn.Module):
            score = score + st_scores + ed_scores
        # return [B,]
        return score
    def forward(self, feats, tags, mask):
        """
        Computes the CRF forward loss; returns a FloatTensor of shape (batch_size,). You usually call mean() on it to obtain the loss.
        :param torch.FloatTensor feats:batch_size x max_len x num_tags, the feature matrix.
        :param torch.FloatTensor feats: batch_size x max_len x num_tags, the feature matrix.
        :param torch.LongTensor tags: batch_size x max_len, the tag matrix.
        :param torch.ByteTensor mask: batch_size x max_len, positions with value 0 are treated as padding.
        :return:torch.FloatTensor, (batch_size,)
        :return: torch.FloatTensor, (batch_size,)
        """
        feats = feats.transpose(0, 1)
        tags = tags.transpose(0, 1).long()
        mask = mask.transpose(0, 1).float()
        all_path_score = self._normalizer_likelihood(feats, mask)
        gold_path_score = self._gold_score(feats, tags, mask)
        return all_path_score - gold_path_score
    def viterbi_decode(self, logits, mask, unpad=False):
        """Given a feature matrix and a transition score matrix, compute the best path and its score
@@ -259,9 +265,9 @@ class ConditionalRandomField(nn.Module):
        """
        batch_size, seq_len, n_tags = logits.size()
        logits = logits.transpose(0, 1).data # L, B, H
        mask = mask.transpose(0, 1).data.byte() # L, B
        logits = logits.transpose(0, 1).data  # L, B, H
        mask = mask.transpose(0, 1).data.byte()  # L, B
        # dp
        vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
        vscore = logits[0]
@@ -269,8 +275,8 @@ class ConditionalRandomField(nn.Module):
        transitions[:n_tags, :n_tags] += self.trans_m.data
        if self.include_start_end_trans:
            transitions[n_tags, :n_tags] += self.start_scores.data
            transitions[:n_tags, n_tags+1] += self.end_scores.data
            transitions[:n_tags, n_tags + 1] += self.end_scores.data
        vscore += transitions[n_tags, :n_tags]
        trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
        for i in range(1, seq_len):
@@ -280,30 +286,29 @@ class ConditionalRandomField(nn.Module):
            best_score, best_dst = score.max(1)
            vpath[i] = best_dst
            vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \
                vscore.masked_fill(mask[i].view(batch_size, 1), 0)
        if self.include_start_end_trans:
            vscore += transitions[:n_tags, n_tags+1].view(1, -1)
            vscore += transitions[:n_tags, n_tags + 1].view(1, -1)
        # backtrace
        batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
        seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
        lens = (mask.long().sum(0) - 1)
        # idxes [L, B], batched idx from seq_len-1 to 0
        idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len
        idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len
        ans = logits.new_empty((seq_len, batch_size), dtype=torch.long)
        ans_score, last_tags = vscore.max(1)
        ans[idxes[0], batch_idx] = last_tags
        for i in range(seq_len - 1):
            last_tags = vpath[idxes[i], batch_idx, last_tags]
            ans[idxes[i+1], batch_idx] = last_tags
            ans[idxes[i + 1], batch_idx] = last_tags
        ans = ans.transpose(0, 1)
        if unpad:
            paths = []
            for idx, seq_len in enumerate(lens):
                paths.append(ans[idx, :seq_len+1].tolist())
                paths.append(ans[idx, :seq_len + 1].tolist())
        else:
            paths = ans
        return paths, ans_score
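A minimal training/decoding sketch for the CRF above, using a hypothetical three-tag BIO vocabulary (not taken from the diff):

import torch
from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions

id2label = {0: 'B-PER', 1: 'I-PER', 2: 'O'}          # hypothetical tag vocabulary
trans = allowed_transitions(id2label, encoding_type='bio')
crf = ConditionalRandomField(num_tags=len(id2label), allowed_transitions=trans)

feats = torch.randn(2, 5, len(id2label))             # batch_size x max_len x num_tags
tags = torch.randint(0, len(id2label), (2, 5))       # batch_size x max_len
mask = torch.ones(2, 5).byte()                       # 0 marks padding
loss = crf(feats, tags, mask).mean()                 # forward() returns a (batch_size,) loss
paths, scores = crf.viterbi_decode(feats, mask, unpad=True)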
@@ -3,20 +3,23 @@ import torch.nn as nn
from ..utils import initial_parameter
__all__ = [
    "MLP"
]
class MLP(nn.Module):
    """
    Alias: :class:`fastNLP.modules.MLP` :class:`fastNLP.modules.decoder.MLP.MLP`
    Alias: :class:`fastNLP.modules.MLP` :class:`fastNLP.modules.decoder.mlp.MLP`
    Multi-layer perceptron.
    :param list size_layer: a list of ints defining the MLP layers; each entry is the hidden size of that layer. The MLP has len(size_layer) - 1 layers.
    :param str or list activation:
        a string, a function, or a list of strings/functions defining the activation of each hidden layer; supported strings are relu, tanh and sigmoid, and the default is relu
    :param str or function output_activation : a string or function defining the activation of the output layer; the default None means the output layer has no activation
    :param List[int] size_layer: a list of ints defining the MLP layers; each entry is the hidden size of that layer. The MLP has len(size_layer) - 1 layers.
    :param Union[str,func,List[str]] activation: a string, a function, or a list of them defining the activation of each hidden layer; supported strings are relu, tanh and sigmoid, and the default is relu
    :param Union[str,func] output_activation: a string or function defining the activation of the output layer; the default None means the output layer has no activation
    :param str initial_method: parameter initialization method
    :param float dropout: dropout probability, default 0
    .. note::
        The activation of the hidden layers is defined by activation; a str/function or a list of str/function can be passed in.
        If a single str/function is passed, all hidden layers use that activation;
@@ -35,10 +38,8 @@ class MLP(nn.Module):
    >>> y = net(x)
    >>> print(x)
    >>> print(y)
    >>>
    """
    def __init__(self, size_layer, activation='relu', output_activation=None, initial_method=None, dropout=0.0):
        super(MLP, self).__init__()
        self.hiddens = nn.ModuleList()
@@ -46,12 +47,12 @@ class MLP(nn.Module):
        self.output_activation = output_activation
        for i in range(1, len(size_layer)):
            if i + 1 == len(size_layer):
                self.output = nn.Linear(size_layer[i-1], size_layer[i])
                self.output = nn.Linear(size_layer[i - 1], size_layer[i])
            else:
                self.hiddens.append(nn.Linear(size_layer[i-1], size_layer[i]))
                self.hiddens.append(nn.Linear(size_layer[i - 1], size_layer[i]))
        self.dropout = nn.Dropout(p=dropout)
        actives = {
            'relu': nn.ReLU(),
            'tanh': nn.Tanh(),
@@ -80,7 +81,7 @@ class MLP(nn.Module):
        else:
            raise ValueError("should set activation correctly: {}".format(activation))
        initial_parameter(self, initial_method)
    def forward(self, x):
        """
        :param torch.Tensor x: the input to the MLP
@@ -93,16 +94,3 @@ class MLP(nn.Module):
            x = self.output_activation(x)
        x = self.dropout(x)
        return x
if __name__ == '__main__':
    net1 = MLP([5, 10, 5])
    net2 = MLP([5, 10, 5], 'tanh')
    net3 = MLP([5, 6, 7, 8, 5], 'tanh')
    net4 = MLP([5, 6, 7, 8, 5], 'relu', output_activation='tanh')
    net5 = MLP([5, 6, 7, 8, 5], ['tanh', 'relu', 'tanh'], 'tanh')
    for net in [net1, net2, net3, net4, net5]:
        x = torch.randn(5, 5)
        y = net(x)
        print(x)
        print(y)
@@ -1,10 +1,13 @@
__all__ = ["viterbi_decode"]
import torch
__all__ = [
    "viterbi_decode"
]
def viterbi_decode(logits, transitions, mask=None, unpad=False):
    """
    Alias: :class:`fastNLP.modules.viterbi_decode` :class:`fastNLP.modules.decoder.utils.viterbi_decode
    r"""
    Alias: :class:`fastNLP.modules.viterbi_decode` :class:`fastNLP.modules.decoder.utils.viterbi_decode`
    Given a feature matrix and a transition score matrix, compute the best path and its score
@@ -20,18 +23,19 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
    """
    batch_size, seq_len, n_tags = logits.size()
    assert n_tags==transitions.size(0) and n_tags==transitions.size(1), "The shapes of transitions and feats are not " \
        "compatible."
    assert n_tags == transitions.size(0) and n_tags == transitions.size(
        1), "The shapes of transitions and feats are not " \
            "compatible."
    logits = logits.transpose(0, 1).data  # L, B, H
    if mask is not None:
        mask = mask.transpose(0, 1).data.byte()  # L, B
    else:
        mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8)
    # dp
    vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
    vscore = logits[0]
    trans_score = transitions.view(1, n_tags, n_tags).data
    for i in range(1, seq_len):
        prev_score = vscore.view(batch_size, n_tags, 1)
@@ -41,14 +45,14 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
        vpath[i] = best_dst
        vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \
            vscore.masked_fill(mask[i].view(batch_size, 1), 0)
    # backtrace
    batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
    seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
    lens = (mask.long().sum(0) - 1)
    # idxes [L, B], batched idx from seq_len-1 to 0
    idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len
    ans = logits.new_empty((seq_len, batch_size), dtype=torch.long)
    ans_score, last_tags = vscore.max(1)
    ans[idxes[0], batch_idx] = last_tags
@@ -62,4 +66,4 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
            paths.append(ans[idx, :seq_len + 1].tolist())
    else:
        paths = ans
    return paths, ans_score
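The standalone viterbi_decode follows the same decoding logic as the CRF method but takes an explicit transition matrix; an illustrative call (the shapes are made up):

import torch
from fastNLP.modules.decoder.utils import viterbi_decode

batch_size, seq_len, n_tags = 2, 6, 4
logits = torch.randn(batch_size, seq_len, n_tags)
transitions = torch.randn(n_tags, n_tags)            # n_tags x n_tags transition scores
mask = torch.ones(batch_size, seq_len).byte()
paths, scores = viterbi_decode(logits, transitions, mask=mask, unpad=True)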
@@ -1,11 +1,29 @@
from .bert import BertModel
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
from .conv_maxpool import ConvMaxpool
from .embedding import Embedding
from .lstm import LSTM
from .bert import BertModel
from .star_transformer import StarTransformer
from .transformer import TransformerEncoder
from .variational_rnn import VarRNN, VarLSTM, VarGRU
__all__ = [
    "LSTM",
    "Embedding",
    # "BertModel",
    "ConvolutionCharEncoder",
    "LSTMCharEncoder",
    "ConvMaxpool",
    "BertModel"
    "Embedding",
    "LSTM",
    "StarTransformer",
    "TransformerEncoder",
    "VarRNN",
    "VarLSTM",
    "VarGRU"
]
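With the expanded __all__ above, the newly re-exported encoders can be imported directly from the subpackage; a sketch assuming an installed fastNLP:

from fastNLP.modules.encoder import ConvolutionCharEncoder, LSTMCharEncoder
from fastNLP.modules.encoder import StarTransformer, VarLSTM, ConvMaxpool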
@@ -1,8 +1,13 @@
import torch
from torch import nn
import torch.nn as nn
from ..utils import initial_parameter
__all__ = [
    "ConvolutionCharEncoder",
    "LSTMCharEncoder"
]
# from torch.nn.init import xavier_uniform
class ConvolutionCharEncoder(nn.Module):
@@ -10,20 +15,22 @@ class ConvolutionCharEncoder(nn.Module):
    Alias: :class:`fastNLP.modules.ConvolutionCharEncoder` :class:`fastNLP.modules.encoder.char_encoder.ConvolutionCharEncoder`
    Character-level convolutional encoder.
    :param int char_emb_size: dimension of the char-level embedding. Default: 50
        Example: with 26 characters, each having a 50-dimensional embedding, the input vector dimension is 50.
        :Example: with 26 characters, each having a 50-dimensional embedding, the input vector dimension is 50.
    :param tuple feature_maps: a tuple of ints; its length is the number of char-level convolutions, and the i-th int is the number of filters of the i-th convolution.
    :param tuple kernels: a tuple of ints; its length is the number of char-level convolutions, and the i-th int is the kernel size of the i-th convolution.
    :param initial_method: parameter initialization method, default `xavier normal`
    """
    def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5), initial_method=None):
        super(ConvolutionCharEncoder, self).__init__()
        self.convs = nn.ModuleList([
            nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4))
            for i in range(len(kernels))])
        initial_parameter(self, initial_method)
    def forward(self, x):
        """
        :param torch.Tensor x: ``[batch_size * sent_length, word_length, char_emb_size]`` embeddings of the input characters
@@ -34,7 +41,7 @@ class ConvolutionCharEncoder(nn.Module):
        x = x.transpose(2, 3)
        # [batch_size*sent_length, channel, height, width]
        return self._convolute(x).unsqueeze(2)
    def _convolute(self, x):
        feats = []
        for conv in self.convs:
@@ -50,7 +57,14 @@ class ConvolutionCharEncoder(nn.Module):
class LSTMCharEncoder(nn.Module):
    """Character-level LSTM-based encoder."""
    """
    Alias: :class:`fastNLP.modules.LSTMCharEncoder` :class:`fastNLP.modules.encoder.char_encoder.LSTMCharEncoder`
    Character-level LSTM-based encoder.
    """
    def __init__(self, char_emb_size=50, hidden_size=None, initial_method=None):
        """
        :param int char_emb_size: dimension of the char-level embedding. Default: 50
@@ -60,14 +74,14 @@ class LSTMCharEncoder(nn.Module):
        """
        super(LSTMCharEncoder, self).__init__()
        self.hidden_size = char_emb_size if hidden_size is None else hidden_size
        self.lstm = nn.LSTM(input_size=char_emb_size,
                            hidden_size=self.hidden_size,
                            num_layers=1,
                            bias=True,
                            batch_first=True)
        initial_parameter(self, initial_method)
    def forward(self, x):
        """
        :param torch.Tensor x: ``[ n_batch*n_word, word_length, char_emb_size]`` embeddings of the input characters
@@ -78,6 +92,6 @@ class LSTMCharEncoder(nn.Module):
        h0 = nn.init.orthogonal_(h0)
        c0 = torch.empty(1, batch_size, self.hidden_size)
        c0 = nn.init.orthogonal_(c0)
        _, hidden = self.lstm(x, (h0, c0))
        return hidden[0].squeeze().unsqueeze(2)
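A shape-oriented sketch of the two char encoders above; all sizes are hypothetical.

import torch
from fastNLP.modules.encoder.char_encoder import ConvolutionCharEncoder, LSTMCharEncoder

chars = torch.randn(8, 7, 50)    # [batch_size * sent_length, word_length, char_emb_size]
conv_enc = ConvolutionCharEncoder(char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5))
conv_out = conv_enc(chars)       # concatenated feature maps per word
lstm_enc = LSTMCharEncoder(char_emb_size=50, hidden_size=50)
lstm_out = lstm_enc(chars)       # last hidden state of the char LSTM per word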
@@ -1,12 +1,13 @@
# python: 3.6
# encoding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import initial_parameter
__all__ = [
    "ConvMaxpool"
]
class ConvMaxpool(nn.Module):
    """
@@ -27,22 +28,24 @@ class ConvMaxpool(nn.Module):
    :param str activation: the convolution output is passed through this activation before max-pooling. Supports relu, sigmoid, tanh
    :param str initial_method: str.
    """
    def __init__(self, in_channels, out_channels, kernel_sizes,
                 stride=1, padding=0, dilation=1,
                 groups=1, bias=True, activation="relu", initial_method=None):
        super(ConvMaxpool, self).__init__()
        # convolution
        if isinstance(kernel_sizes, (list, tuple, int)):
            if isinstance(kernel_sizes, int) and isinstance(out_channels, int):
                out_channels = [out_channels]
                kernel_sizes = [kernel_sizes]
            elif isinstance(kernel_sizes, (tuple, list)) and isinstance(out_channels, (tuple, list)):
                assert len(out_channels)==len(kernel_sizes), "The number of out_channels should be equal to the number" \
                    " of kernel_sizes."
                assert len(out_channels) == len(
                    kernel_sizes), "The number of out_channels should be equal to the number" \
                                   " of kernel_sizes."
            else:
                raise ValueError("The type of out_channels and kernel_sizes should be the same.")
            self.convs = nn.ModuleList([nn.Conv1d(
                in_channels=in_channels,
                out_channels=oc,
@@ -53,11 +56,11 @@ class ConvMaxpool(nn.Module):
                groups=groups,
                bias=bias)
                for oc, ks in zip(out_channels, kernel_sizes)])
        else:
            raise Exception(
                'Incorrect kernel sizes: should be list, tuple or int')
        # activation function
        if activation == 'relu':
            self.activation = F.relu
@@ -68,9 +71,9 @@ class ConvMaxpool(nn.Module):
        else:
            raise Exception(
                "Undefined activation function: choose from: relu, tanh, sigmoid")
        initial_parameter(self, initial_method)
    def forward(self, x, mask=None):
        """
@@ -83,9 +86,9 @@ class ConvMaxpool(nn.Module):
        # convolution
        xs = [self.activation(conv(x)) for conv in self.convs]  # [[N,C,L], ...]
        if mask is not None:
            mask = mask.unsqueeze(1) # B x 1 x L
            mask = mask.unsqueeze(1)  # B x 1 x L
            xs = [x.masked_fill_(mask, float('-inf')) for x in xs]
        # max-pooling
        xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2)
              for i in xs]  # [[N, C], ...]
        return torch.cat(xs, dim=-1)  # [N, C]
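ConvMaxpool runs several 1-d convolutions in parallel, applies the activation, max-pools over time, and concatenates the results; a sketch with made-up sizes (the [batch, seq_len, in_channels] input layout is assumed from the library's usage, not shown in this hunk):

import torch
from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool

cnn = ConvMaxpool(in_channels=128, out_channels=[64, 64, 64], kernel_sizes=[3, 4, 5])
x = torch.randn(16, 30, 128)     # assumed layout: [batch, seq_len, in_channels]
out = cnn(x)                     # pooled features concatenated over the three kernels: [16, 192]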
@@ -1,14 +1,19 @@
import torch.nn as nn
from ..utils import get_embeddings
__all__ = [
    "Embedding"
]
class Embedding(nn.Embedding):
    """
    Alias: :class:`fastNLP.modules.Embedding` :class:`fastNLP.modules.encoder.embedding.Embedding`
    Embedding component. self.num_embeddings gives the vocabulary size and self.embedding_dim gives the embedding dimension."""
    def __init__(self, init_embed, padding_idx=None, dropout=0.0, sparse=False, max_norm=None, norm_type=2,
        scale_grad_by_freq=False):
                 scale_grad_by_freq=False):
        """
        :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: the Embedding specification (pass a tuple(int, int),
@@ -22,14 +27,14 @@ class Embedding(nn.Embedding):
        """
        embed = get_embeddings(init_embed)
        num_embeddings, embedding_dim = embed.weight.size()
        super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx,
            max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse, _weight=embed.weight.data)
                         max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq,
                         sparse=sparse, _weight=embed.weight.data)
        del embed
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        """
        :param torch.LongTensor x: [batch, seq_len]
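The Embedding wrapper accepts several init_embed forms (a (vocab_size, dim) tuple, a tensor, an nn.Embedding, or an ndarray); a sketch with made-up sizes:

import torch
from fastNLP.modules.encoder.embedding import Embedding

emb = Embedding((1000, 100), dropout=0.1)    # vocabulary of 1000 words, 100-dim vectors
words = torch.randint(0, 1000, (4, 12))      # [batch, seq_len] word ids
vectors = emb(words)                         # [4, 12, 100]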
@@ -1,4 +1,5 @@ | |||||
"""轻量封装的 Pytorch LSTM 模块. | |||||
""" | |||||
轻量封装的 Pytorch LSTM 模块. | |||||
可在 forward 时传入序列的长度, 自动对padding做合适的处理. | 可在 forward 时传入序列的长度, 自动对padding做合适的处理. | ||||
""" | """ | ||||
import torch | import torch | ||||
@@ -7,6 +8,10 @@ import torch.nn.utils.rnn as rnn | |||||
from ..utils import initial_parameter | from ..utils import initial_parameter | ||||
__all__ = [ | |||||
"LSTM" | |||||
] | |||||
class LSTM(nn.Module): | class LSTM(nn.Module): | ||||
""" | """ | ||||
@@ -23,6 +28,7 @@ class LSTM(nn.Module): | |||||
:(batch, seq, feature). Default: ``False`` | :(batch, seq, feature). Default: ``False`` | ||||
:param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` | :param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` | ||||
""" | """ | ||||
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, | def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, | ||||
bidirectional=False, bias=True, initial_method=None): | bidirectional=False, bias=True, initial_method=None): | ||||
super(LSTM, self).__init__() | super(LSTM, self).__init__() | ||||
@@ -30,7 +36,7 @@ class LSTM(nn.Module): | |||||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, | self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, | ||||
dropout=dropout, bidirectional=bidirectional) | dropout=dropout, bidirectional=bidirectional) | ||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def forward(self, x, seq_len=None, h0=None, c0=None): | def forward(self, x, seq_len=None, h0=None, c0=None): | ||||
""" | """ | ||||
@@ -1,9 +1,14 @@ | |||||
"""Star-Transformer 的encoder部分的 Pytorch 实现 | |||||
""" | """ | ||||
Star-Transformer 的encoder部分的 Pytorch 实现 | |||||
""" | |||||
import numpy as NP | |||||
import torch | import torch | ||||
from torch import nn | from torch import nn | ||||
from torch.nn import functional as F | from torch.nn import functional as F | ||||
import numpy as NP | |||||
__all__ = [ | |||||
"StarTransformer" | |||||
] | |||||
class StarTransformer(nn.Module): | class StarTransformer(nn.Module): | ||||
@@ -24,10 +29,11 @@ class StarTransformer(nn.Module): | |||||
the model adds a position embedding to the input sequence. | the model adds a position embedding to the input sequence. | ||||
If `None`, the position-embedding step is skipped. Default: `None` | If `None`, the position-embedding step is skipped. Default: `None` | ||||
""" | """ | ||||
def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None): | def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None): | ||||
super(StarTransformer, self).__init__() | super(StarTransformer, self).__init__() | ||||
self.iters = num_layers | self.iters = num_layers | ||||
self.norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(self.iters)]) | self.norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(self.iters)]) | ||||
self.ring_att = nn.ModuleList( | self.ring_att = nn.ModuleList( | ||||
[_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | [_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | ||||
@@ -35,12 +41,12 @@ class StarTransformer(nn.Module): | |||||
self.star_att = nn.ModuleList( | self.star_att = nn.ModuleList( | ||||
[_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | [_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | ||||
for _ in range(self.iters)]) | for _ in range(self.iters)]) | ||||
if max_len is not None: | if max_len is not None: | ||||
self.pos_emb = nn.Embedding(max_len, hidden_size) | self.pos_emb = nn.Embedding(max_len, hidden_size) | ||||
else: | else: | ||||
self.pos_emb = None | self.pos_emb = None | ||||
def forward(self, data, mask): | def forward(self, data, mask): | ||||
""" | """ | ||||
:param FloatTensor data: [batch, length, hidden] the input sequence | :param FloatTensor data: [batch, length, hidden] the input sequence | ||||
@@ -50,20 +56,21 @@ class StarTransformer(nn.Module): | |||||
[batch, hidden] the global relay node; see the paper for details | [batch, hidden] the global relay node; see the paper for details | ||||
""" | """ | ||||
def norm_func(f, x): | def norm_func(f, x): | ||||
# B, H, L, 1 | # B, H, L, 1 | ||||
return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) | return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) | ||||
B, L, H = data.size() | B, L, H = data.size() | ||||
mask = (mask == 0) # flip the mask for masked_fill_ | |||||
mask = (mask == 0) # flip the mask for masked_fill_ | |||||
smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) | smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) | ||||
embs = data.permute(0, 2, 1)[:,:,:,None] # B H L 1 | |||||
embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1 | |||||
if self.pos_emb: | if self.pos_emb: | ||||
P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device)\ | |||||
.view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1 | |||||
P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \ | |||||
.view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1 | |||||
embs = embs + P | embs = embs + P | ||||
nodes = embs | nodes = embs | ||||
relay = embs.mean(2, keepdim=True) | relay = embs.mean(2, keepdim=True) | ||||
ex_mask = mask[:, None, :, None].expand(B, H, L, 1) | ex_mask = mask[:, None, :, None].expand(B, H, L, 1) | ||||
@@ -72,11 +79,11 @@ class StarTransformer(nn.Module): | |||||
ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2) | ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2) | ||||
nodes = nodes + F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)) | nodes = nodes + F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)) | ||||
relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask)) | relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask)) | ||||
nodes = nodes.masked_fill_(ex_mask, 0) | nodes = nodes.masked_fill_(ex_mask, 0) | ||||
nodes = nodes.view(B, H, L).permute(0, 2, 1) | nodes = nodes.view(B, H, L).permute(0, 2, 1) | ||||
return nodes, relay.view(B, H) | return nodes, relay.view(B, H) | ||||
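As an illustration of the interface documented above, a short sketch with assumed sizes; the mask uses 1 for real tokens and 0 for padding, matching the flip performed in forward (not part of the diff):

```python
import torch
from fastNLP.modules.encoder.star_transformer import StarTransformer

encoder = StarTransformer(hidden_size=64, num_layers=2, num_head=4, head_dim=16, max_len=50)
data = torch.randn(2, 10, 64)    # [batch, length, hidden]
mask = torch.ones(2, 10).byte()  # 1 = real token, 0 = padding
nodes, relay = encoder(data, mask)
# nodes: [2, 10, 64] per-token states; relay: [2, 64] global relay node
```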
@@ -89,37 +96,37 @@ class _MSA1(nn.Module): | |||||
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) | self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) | ||||
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) | self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) | ||||
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) | self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) | ||||
self.drop = nn.Dropout(dropout) | self.drop = nn.Dropout(dropout) | ||||
# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) | # print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) | ||||
self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3 | self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3 | ||||
def forward(self, x, ax=None): | def forward(self, x, ax=None): | ||||
# x: B, H, L, 1, ax : B, H, X, L append features | # x: B, H, L, 1, ax : B, H, X, L append features | ||||
nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size | nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size | ||||
B, H, L, _ = x.shape | B, H, L, _ = x.shape | ||||
q, k, v = self.WQ(x), self.WK(x), self.WV(x) # x: (B,H,L,1) | q, k, v = self.WQ(x), self.WK(x), self.WV(x) # x: (B,H,L,1) | ||||
if ax is not None: | if ax is not None: | ||||
aL = ax.shape[2] | aL = ax.shape[2] | ||||
ak = self.WK(ax).view(B, nhead, head_dim, aL, L) | ak = self.WK(ax).view(B, nhead, head_dim, aL, L) | ||||
av = self.WV(ax).view(B, nhead, head_dim, aL, L) | av = self.WV(ax).view(B, nhead, head_dim, aL, L) | ||||
q = q.view(B, nhead, head_dim, 1, L) | q = q.view(B, nhead, head_dim, 1, L) | ||||
k = F.unfold(k.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0))\ | |||||
.view(B, nhead, head_dim, unfold_size, L) | |||||
v = F.unfold(v.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0))\ | |||||
.view(B, nhead, head_dim, unfold_size, L) | |||||
k = F.unfold(k.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \ | |||||
.view(B, nhead, head_dim, unfold_size, L) | |||||
v = F.unfold(v.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \ | |||||
.view(B, nhead, head_dim, unfold_size, L) | |||||
if ax is not None: | if ax is not None: | ||||
k = torch.cat([k, ak], 3) | k = torch.cat([k, ak], 3) | ||||
v = torch.cat([v, av], 3) | v = torch.cat([v, av], 3) | ||||
alphas = self.drop(F.softmax((q * k).sum(2, keepdim=True) / NP.sqrt(head_dim), 3)) # B N L 1 U | alphas = self.drop(F.softmax((q * k).sum(2, keepdim=True) / NP.sqrt(head_dim), 3)) # B N L 1 U | ||||
att = (alphas * v).sum(3).view(B, nhead * head_dim, L, 1) | att = (alphas * v).sum(3).view(B, nhead * head_dim, L, 1) | ||||
ret = self.WO(att) | ret = self.WO(att) | ||||
return ret | return ret | ||||
@@ -131,19 +138,19 @@ class _MSA2(nn.Module): | |||||
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) | self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) | ||||
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) | self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) | ||||
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) | self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) | ||||
self.drop = nn.Dropout(dropout) | self.drop = nn.Dropout(dropout) | ||||
# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) | # print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) | ||||
self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3 | self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3 | ||||
def forward(self, x, y, mask=None): | def forward(self, x, y, mask=None): | ||||
# x: B, H, 1, 1, 1 y: B H L 1 | # x: B, H, 1, 1, 1 y: B H L 1 | ||||
nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size | nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size | ||||
B, H, L, _ = y.shape | B, H, L, _ = y.shape | ||||
q, k, v = self.WQ(x), self.WK(y), self.WV(y) | q, k, v = self.WQ(x), self.WK(y), self.WV(y) | ||||
q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h | q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h | ||||
k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L | k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L | ||||
v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h | v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h | ||||
@@ -3,6 +3,10 @@ from torch import nn | |||||
from ..aggregator.attention import MultiHeadAttention | from ..aggregator.attention import MultiHeadAttention | ||||
from ..dropout import TimestepDropout | from ..dropout import TimestepDropout | ||||
__all__ = [ | |||||
"TransformerEncoder" | |||||
] | |||||
class TransformerEncoder(nn.Module): | class TransformerEncoder(nn.Module): | ||||
""" | """ | ||||
@@ -19,6 +23,7 @@ class TransformerEncoder(nn.Module): | |||||
:param int num_head: the number of attention heads. | :param int num_head: the number of attention heads. | ||||
:param float dropout: dropout probability. Default: 0.1 | :param float dropout: dropout probability. Default: 0.1 | ||||
""" | """ | ||||
class SubLayer(nn.Module): | class SubLayer(nn.Module): | ||||
def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): | def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): | ||||
super(TransformerEncoder.SubLayer, self).__init__() | super(TransformerEncoder.SubLayer, self).__init__() | ||||
@@ -27,9 +32,9 @@ class TransformerEncoder(nn.Module): | |||||
self.ffn = nn.Sequential(nn.Linear(model_size, inner_size), | self.ffn = nn.Sequential(nn.Linear(model_size, inner_size), | ||||
nn.ReLU(), | nn.ReLU(), | ||||
nn.Linear(inner_size, model_size), | nn.Linear(inner_size, model_size), | ||||
TimestepDropout(dropout),) | |||||
TimestepDropout(dropout), ) | |||||
self.norm2 = nn.LayerNorm(model_size) | self.norm2 = nn.LayerNorm(model_size) | ||||
def forward(self, input, seq_mask=None, atte_mask_out=None): | def forward(self, input, seq_mask=None, atte_mask_out=None): | ||||
""" | """ | ||||
@@ -44,11 +49,11 @@ class TransformerEncoder(nn.Module): | |||||
output = self.norm2(output + norm_atte) | output = self.norm2(output + norm_atte) | ||||
output *= seq_mask | output *= seq_mask | ||||
return output | return output | ||||
def __init__(self, num_layers, **kargs): | def __init__(self, num_layers, **kargs): | ||||
super(TransformerEncoder, self).__init__() | super(TransformerEncoder, self).__init__() | ||||
self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)]) | self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)]) | ||||
def forward(self, x, seq_mask=None): | def forward(self, x, seq_mask=None): | ||||
""" | """ | ||||
:param x: [batch, seq_len, model_size] the input sequence | :param x: [batch, seq_len, model_size] the input sequence | ||||
@@ -60,8 +65,8 @@ class TransformerEncoder(nn.Module): | |||||
if seq_mask is None: | if seq_mask is None: | ||||
atte_mask_out = None | atte_mask_out = None | ||||
else: | else: | ||||
atte_mask_out = (seq_mask < 1)[:,None,:] | |||||
seq_mask = seq_mask[:,:,None] | |||||
atte_mask_out = (seq_mask < 1)[:, None, :] | |||||
seq_mask = seq_mask[:, :, None] | |||||
for layer in self.layers: | for layer in self.layers: | ||||
output = layer(output, seq_mask, atte_mask_out) | output = layer(output, seq_mask, atte_mask_out) | ||||
return output | return output |
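A minimal sketch of how this encoder could be instantiated; the keyword arguments are forwarded to SubLayer, so the names below mirror its signature, and the sizes are assumptions for illustration (not part of the diff):

```python
import torch
from fastNLP.modules.encoder.transformer import TransformerEncoder

encoder = TransformerEncoder(num_layers=2, model_size=64, inner_size=128,
                             key_size=16, value_size=16, num_head=4, dropout=0.1)
x = torch.randn(2, 10, 64)           # [batch, seq_len, model_size]
seq_mask = torch.ones(2, 10)         # 1 = real token, 0 = padding
out = encoder(x, seq_mask=seq_mask)  # [batch, seq_len, model_size]
```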
@@ -1,9 +1,9 @@ | |||||
"""Variational RNN 的 Pytorch 实现 | |||||
""" | |||||
Variational RNN 的 Pytorch 实现 | |||||
""" | """ | ||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence | ||||
from ..utils import initial_parameter | |||||
try: | try: | ||||
from torch import flip | from torch import flip | ||||
@@ -14,18 +14,27 @@ except ImportError: | |||||
indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device) | indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device) | ||||
return x[tuple(indices)] | return x[tuple(indices)] | ||||
from ..utils import initial_parameter | |||||
__all__ = [ | |||||
"VarRNN", | |||||
"VarLSTM", | |||||
"VarGRU" | |||||
] | |||||
class VarRnnCellWrapper(nn.Module): | class VarRnnCellWrapper(nn.Module): | ||||
"""Wrapper for normal RNN Cells, make it support variational dropout | |||||
""" | """ | ||||
Wrapper for normal RNN Cells, make it support variational dropout | |||||
""" | |||||
def __init__(self, cell, hidden_size, input_p, hidden_p): | def __init__(self, cell, hidden_size, input_p, hidden_p): | ||||
super(VarRnnCellWrapper, self).__init__() | super(VarRnnCellWrapper, self).__init__() | ||||
self.cell = cell | self.cell = cell | ||||
self.hidden_size = hidden_size | self.hidden_size = hidden_size | ||||
self.input_p = input_p | self.input_p = input_p | ||||
self.hidden_p = hidden_p | self.hidden_p = hidden_p | ||||
def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False): | def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False): | ||||
""" | """ | ||||
:param PackedSequence input_x: [seq_len, batch_size, input_size] | :param PackedSequence input_x: [seq_len, batch_size, input_size] | ||||
@@ -37,11 +46,13 @@ class VarRnnCellWrapper(nn.Module): | |||||
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] | hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] | ||||
for other RNN, h_n, [batch_size, hidden_size] | for other RNN, h_n, [batch_size, hidden_size] | ||||
""" | """ | ||||
def get_hi(hi, h0, size): | def get_hi(hi, h0, size): | ||||
h0_size = size - hi.size(0) | h0_size = size - hi.size(0) | ||||
if h0_size > 0: | if h0_size > 0: | ||||
return torch.cat([hi, h0[:h0_size]], dim=0) | return torch.cat([hi, h0[:h0_size]], dim=0) | ||||
return hi[:size] | return hi[:size] | ||||
is_lstm = isinstance(hidden, tuple) | is_lstm = isinstance(hidden, tuple) | ||||
input, batch_sizes = input_x.data, input_x.batch_sizes | input, batch_sizes = input_x.data, input_x.batch_sizes | ||||
output = [] | output = [] | ||||
@@ -52,7 +63,7 @@ class VarRnnCellWrapper(nn.Module): | |||||
else: | else: | ||||
batch_iter = batch_sizes | batch_iter = batch_sizes | ||||
idx = 0 | idx = 0 | ||||
if is_lstm: | if is_lstm: | ||||
hn = (hidden[0].clone(), hidden[1].clone()) | hn = (hidden[0].clone(), hidden[1].clone()) | ||||
else: | else: | ||||
@@ -60,10 +71,10 @@ class VarRnnCellWrapper(nn.Module): | |||||
hi = hidden | hi = hidden | ||||
for size in batch_iter: | for size in batch_iter: | ||||
if is_reversed: | if is_reversed: | ||||
input_i = input[idx-size: idx] * mask_x[:size] | |||||
input_i = input[idx - size: idx] * mask_x[:size] | |||||
idx -= size | idx -= size | ||||
else: | else: | ||||
input_i = input[idx: idx+size] * mask_x[:size] | |||||
input_i = input[idx: idx + size] * mask_x[:size] | |||||
idx += size | idx += size | ||||
mask_hi = mask_h[:size] | mask_hi = mask_h[:size] | ||||
if is_lstm: | if is_lstm: | ||||
@@ -78,7 +89,7 @@ class VarRnnCellWrapper(nn.Module): | |||||
hi = cell(input_i, hi) | hi = cell(input_i, hi) | ||||
hn[:size] = hi | hn[:size] = hi | ||||
output.append(hi) | output.append(hi) | ||||
if is_reversed: | if is_reversed: | ||||
output = list(reversed(output)) | output = list(reversed(output)) | ||||
output = torch.cat(output, dim=0) | output = torch.cat(output, dim=0) | ||||
@@ -86,7 +97,9 @@ class VarRnnCellWrapper(nn.Module): | |||||
class VarRNNBase(nn.Module): | class VarRNNBase(nn.Module): | ||||
"""Variational Dropout RNN 实现. | |||||
""" | |||||
Variational Dropout RNN 实现. | |||||
论文参考: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) | 论文参考: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) | ||||
https://arxiv.org/abs/1512.05287`. | https://arxiv.org/abs/1512.05287`. | ||||
@@ -102,7 +115,7 @@ class VarRNNBase(nn.Module): | |||||
:param hidden_dropout: dropout probability applied to each hidden state. Default: 0 | :param hidden_dropout: dropout probability applied to each hidden state. Default: 0 | ||||
:param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False`` | :param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False`` | ||||
""" | """ | ||||
def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, | def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, | ||||
bias=True, batch_first=False, | bias=True, batch_first=False, | ||||
input_dropout=0, hidden_dropout=0, bidirectional=False): | input_dropout=0, hidden_dropout=0, bidirectional=False): | ||||
@@ -125,7 +138,7 @@ class VarRNNBase(nn.Module): | |||||
self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout)) | self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout)) | ||||
initial_parameter(self) | initial_parameter(self) | ||||
self.is_lstm = (self.mode == "LSTM") | self.is_lstm = (self.mode == "LSTM") | ||||
def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h): | def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h): | ||||
is_lstm = self.is_lstm | is_lstm = self.is_lstm | ||||
idx = self.num_directions * n_layer + n_direction | idx = self.num_directions * n_layer + n_direction | ||||
@@ -133,7 +146,7 @@ class VarRNNBase(nn.Module): | |||||
hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx] | hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx] | ||||
output_x, hidden_x = cell(input, hi, mask_x, mask_h, is_reversed=(n_direction == 1)) | output_x, hidden_x = cell(input, hi, mask_x, mask_h, is_reversed=(n_direction == 1)) | ||||
return output_x, hidden_x | return output_x, hidden_x | ||||
def forward(self, x, hx=None): | def forward(self, x, hx=None): | ||||
""" | """ | ||||
@@ -152,19 +165,19 @@ class VarRNNBase(nn.Module): | |||||
else: | else: | ||||
max_batch_size = int(input.batch_sizes[0]) | max_batch_size = int(input.batch_sizes[0]) | ||||
input, batch_sizes = input.data, input.batch_sizes | input, batch_sizes = input.data, input.batch_sizes | ||||
if hx is None: | if hx is None: | ||||
hx = x.new_zeros(self.num_layers * self.num_directions, | hx = x.new_zeros(self.num_layers * self.num_directions, | ||||
max_batch_size, self.hidden_size, requires_grad=True) | max_batch_size, self.hidden_size, requires_grad=True) | ||||
if is_lstm: | if is_lstm: | ||||
hx = (hx, hx.new_zeros(hx.size(), requires_grad=True)) | hx = (hx, hx.new_zeros(hx.size(), requires_grad=True)) | ||||
mask_x = x.new_ones((max_batch_size, self.input_size)) | mask_x = x.new_ones((max_batch_size, self.input_size)) | ||||
mask_out = x.new_ones((max_batch_size, self.hidden_size * self.num_directions)) | mask_out = x.new_ones((max_batch_size, self.hidden_size * self.num_directions)) | ||||
mask_h_ones = x.new_ones((max_batch_size, self.hidden_size)) | mask_h_ones = x.new_ones((max_batch_size, self.hidden_size)) | ||||
nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True) | nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True) | ||||
nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True) | nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True) | ||||
hidden = x.new_zeros((self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) | hidden = x.new_zeros((self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) | ||||
if is_lstm: | if is_lstm: | ||||
cellstate = x.new_zeros((self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) | cellstate = x.new_zeros((self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) | ||||
@@ -183,18 +196,19 @@ class VarRNNBase(nn.Module): | |||||
else: | else: | ||||
hidden[idx] = hidden_x | hidden[idx] = hidden_x | ||||
x = torch.cat(output_list, dim=-1) | x = torch.cat(output_list, dim=-1) | ||||
if is_lstm: | if is_lstm: | ||||
hidden = (hidden, cellstate) | hidden = (hidden, cellstate) | ||||
if is_packed: | if is_packed: | ||||
output = PackedSequence(x, batch_sizes) | output = PackedSequence(x, batch_sizes) | ||||
else: | else: | ||||
x = PackedSequence(x, batch_sizes) | x = PackedSequence(x, batch_sizes) | ||||
output, _ = pad_packed_sequence(x, batch_first=self.batch_first) | output, _ = pad_packed_sequence(x, batch_first=self.batch_first) | ||||
return output, hidden | return output, hidden | ||||
class VarLSTM(VarRNNBase): | class VarLSTM(VarRNNBase): | ||||
""" | """ | ||||
Alias: :class:`fastNLP.modules.VarLSTM` :class:`fastNLP.modules.encoder.variational_rnn.VarLSTM` | Alias: :class:`fastNLP.modules.VarLSTM` :class:`fastNLP.modules.encoder.variational_rnn.VarLSTM` | ||||
@@ -211,10 +225,10 @@ class VarLSTM(VarRNNBase): | |||||
:param hidden_dropout: dropout probability applied to each hidden state. Default: 0 | :param hidden_dropout: dropout probability applied to each hidden state. Default: 0 | ||||
:param bidirectional: if ``True``, use a bidirectional LSTM. Default: ``False`` | :param bidirectional: if ``True``, use a bidirectional LSTM. Default: ``False`` | ||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs) | super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs) | ||||
def forward(self, x, hx=None): | def forward(self, x, hx=None): | ||||
return super(VarLSTM, self).forward(x, hx) | return super(VarLSTM, self).forward(x, hx) | ||||
@@ -235,13 +249,14 @@ class VarRNN(VarRNNBase): | |||||
:param hidden_dropout: dropout probability applied to each hidden state. Default: 0 | :param hidden_dropout: dropout probability applied to each hidden state. Default: 0 | ||||
:param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False`` | :param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False`` | ||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
super(VarRNN, self).__init__(mode="RNN", Cell=nn.RNNCell, *args, **kwargs) | super(VarRNN, self).__init__(mode="RNN", Cell=nn.RNNCell, *args, **kwargs) | ||||
def forward(self, x, hx=None): | def forward(self, x, hx=None): | ||||
return super(VarRNN, self).forward(x, hx) | return super(VarRNN, self).forward(x, hx) | ||||
class VarGRU(VarRNNBase): | class VarGRU(VarRNNBase): | ||||
""" | """ | ||||
Alias: :class:`fastNLP.modules.VarGRU` :class:`fastNLP.modules.encoder.variational_rnn.VarGRU` | Alias: :class:`fastNLP.modules.VarGRU` :class:`fastNLP.modules.encoder.variational_rnn.VarGRU` | ||||
@@ -258,10 +273,9 @@ class VarGRU(VarRNNBase): | |||||
:param hidden_dropout: dropout probability applied to each hidden state. Default: 0 | :param hidden_dropout: dropout probability applied to each hidden state. Default: 0 | ||||
:param bidirectional: if ``True``, use a bidirectional GRU. Default: ``False`` | :param bidirectional: if ``True``, use a bidirectional GRU. Default: ``False`` | ||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs) | super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs) | ||||
def forward(self, x, hx=None): | def forward(self, x, hx=None): | ||||
return super(VarGRU, self).forward(x, hx) | return super(VarGRU, self).forward(x, hx) | ||||
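To illustrate the variational-dropout RNNs above, a minimal sketch with assumed hyperparameters; the same dropout mask is reused across time steps, following Gal & Ghahramani (2016), and the return convention is assumed to match the standard LSTM (not part of the diff):

```python
import torch
from fastNLP.modules.encoder.variational_rnn import VarLSTM

rnn = VarLSTM(input_size=50, hidden_size=100, num_layers=1, batch_first=True,
              input_dropout=0.3, hidden_dropout=0.3, bidirectional=True)
x = torch.randn(4, 7, 50)    # [batch, seq_len, input_size]
output, (h_n, c_n) = rnn(x)  # output assumed [4, 7, 200] with bidirectional=True
```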
@@ -3,7 +3,7 @@ import torch | |||||
from torch import nn | from torch import nn | ||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from fastNLP.modules.decoder.MLP import MLP | |||||
from fastNLP.modules.decoder.mlp import MLP | |||||
from reproduction.Chinese_word_segmentation.utils import seq_lens_to_mask | from reproduction.Chinese_word_segmentation.utils import seq_lens_to_mask | ||||
@@ -120,8 +120,8 @@ class CWSBiLSTMSegApp(BaseModel): | |||||
return {'pred_tags': pred_tags} | return {'pred_tags': pred_tags} | ||||
from fastNLP.modules.decoder.CRF import ConditionalRandomField | |||||
from fastNLP.modules.decoder.CRF import allowed_transitions | |||||
from fastNLP.modules.decoder.crf import ConditionalRandomField | |||||
from fastNLP.modules.decoder.crf import allowed_transitions | |||||
class CWSBiLSTMCRF(BaseModel): | class CWSBiLSTMCRF(BaseModel): | ||||
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | ||||
@@ -10,8 +10,8 @@ from torch import nn | |||||
import torch | import torch | ||||
# from fastNLP.modules.encoder.transformer import TransformerEncoder | # from fastNLP.modules.encoder.transformer import TransformerEncoder | ||||
from reproduction.Chinese_word_segmentation.models.transformer import TransformerEncoder | from reproduction.Chinese_word_segmentation.models.transformer import TransformerEncoder | ||||
from fastNLP.modules.decoder.CRF import ConditionalRandomField,seq_len_to_byte_mask | |||||
from fastNLP.modules.decoder.CRF import allowed_transitions | |||||
from fastNLP.modules.decoder.crf import ConditionalRandomField,seq_len_to_byte_mask | |||||
from fastNLP.modules.decoder.crf import allowed_transitions | |||||
class TransformerCWS(nn.Module): | class TransformerCWS(nn.Module): | ||||
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | ||||
@@ -7,7 +7,7 @@ from fastNLP.io.config_io import ConfigSection | |||||
from fastNLP.io.dataset_loader import DummyClassificationReader as Dataset_loader | from fastNLP.io.dataset_loader import DummyClassificationReader as Dataset_loader | ||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from fastNLP.modules.aggregator.self_attention import SelfAttention | from fastNLP.modules.aggregator.self_attention import SelfAttention | ||||
from fastNLP.modules.decoder.MLP import MLP | |||||
from fastNLP.modules.decoder.mlp import MLP | |||||
from fastNLP.modules.encoder.embedding import Embedding as Embedding | from fastNLP.modules.encoder.embedding import Embedding as Embedding | ||||
from fastNLP.modules.encoder.lstm import LSTM | from fastNLP.modules.encoder.lstm import LSTM | ||||
@@ -5,7 +5,7 @@ import unittest | |||||
class TestCRF(unittest.TestCase): | class TestCRF(unittest.TestCase): | ||||
def test_case1(self): | def test_case1(self): | ||||
# Check that allowed_transitions() can be used correctly | # Check that allowed_transitions() can be used correctly | ||||
from fastNLP.modules.decoder.CRF import allowed_transitions | |||||
from fastNLP.modules.decoder.crf import allowed_transitions | |||||
id2label = {0: 'B', 1: 'I', 2:'O'} | id2label = {0: 'B', 1: 'I', 2:'O'} | ||||
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | ||||
@@ -43,7 +43,7 @@ class TestCRF(unittest.TestCase): | |||||
# Test that the CRF avoids decoding illegal transitions; verified against allennlp. | # Test that the CRF avoids decoding illegal transitions; verified against allennlp. | ||||
pass | pass | ||||
# import torch | # import torch | ||||
# from fastNLP.modules.decoder.CRF import seq_len_to_byte_mask | |||||
# from fastNLP.modules.decoder.crf import seq_len_to_byte_mask | |||||
# | # | ||||
# labels = ['O'] | # labels = ['O'] | ||||
# for label in ['X', 'Y']: | # for label in ['X', 'Y']: | ||||
@@ -63,7 +63,7 @@ class TestCRF(unittest.TestCase): | |||||
# mask = seq_len_to_byte_mask(seq_lens) | # mask = seq_len_to_byte_mask(seq_lens) | ||||
# allen_res = allen_CRF.viterbi_tags(logits, mask) | # allen_res = allen_CRF.viterbi_tags(logits, mask) | ||||
# | # | ||||
# from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions | |||||
# from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions | |||||
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label)) | # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label)) | ||||
# fast_CRF.trans_m = trans_m | # fast_CRF.trans_m = trans_m | ||||
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True) | # fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True) | ||||
@@ -91,7 +91,7 @@ class TestCRF(unittest.TestCase): | |||||
# mask = seq_len_to_byte_mask(seq_lens) | # mask = seq_len_to_byte_mask(seq_lens) | ||||
# allen_res = allen_CRF.viterbi_tags(logits, mask) | # allen_res = allen_CRF.viterbi_tags(logits, mask) | ||||
# | # | ||||
# from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions | |||||
# from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions | |||||
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | ||||
# encoding_type='BMES')) | # encoding_type='BMES')) | ||||
# fast_CRF.trans_m = trans_m | # fast_CRF.trans_m = trans_m | ||||
@@ -104,7 +104,7 @@ class TestCRF(unittest.TestCase): | |||||
def test_case3(self): | def test_case3(self): | ||||
# Test that the CRF loss never becomes negative | # Test that the CRF loss never becomes negative | ||||
import torch | import torch | ||||
from fastNLP.modules.decoder.CRF import ConditionalRandomField | |||||
from fastNLP.modules.decoder.crf import ConditionalRandomField | |||||
from fastNLP.core.utils import seq_len_to_mask | from fastNLP.core.utils import seq_len_to_mask | ||||
from torch import optim | from torch import optim | ||||
from torch import nn | from torch import nn | ||||
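Along the lines of the test above, a minimal sketch of computing a CRF loss through the renamed crf module; the tag set, shapes, and random inputs are assumptions for illustration (not part of the diff):

```python
import torch
from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions
from fastNLP.core.utils import seq_len_to_mask

id2label = {0: 'B', 1: 'I', 2: 'O'}
crf = ConditionalRandomField(num_tags=3,
                             allowed_transitions=allowed_transitions(id2label))
feats = torch.randn(2, 5, 3, requires_grad=True)  # [batch, seq_len, num_tags] emission scores
tags = torch.randint(0, 3, (2, 5))                # gold tag ids
mask = seq_len_to_mask(torch.LongTensor([5, 3]))  # [batch, seq_len]
loss = crf(feats, tags, mask).mean()              # expected to stay non-negative
loss.backward()
```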