
Aliases for the decoder module

tags/v0.4.10
ChenXin, 6 years ago
commit d6ae241bbb
3 changed files with 36 additions and 36 deletions:

  1. fastNLP/modules/decoder/crf.py   (+26 -26)
  2. fastNLP/modules/decoder/mlp.py   (+5 -5)
  3. fastNLP/modules/decoder/utils.py (+5 -5)

fastNLP/modules/decoder/crf.py (+26 -26)

@@ -11,7 +11,7 @@ from ..utils import initial_parameter

def allowed_transitions(id2target, encoding_type='bio', include_start_end=False):
"""
- Alias: :class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.crf.allowed_transitions`
+ Alias: :class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.allowed_transitions`

Given a mapping from id to label, return the list of all allowed (from_tag_id, to_tag_id) transitions.

@@ -31,7 +31,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=False)
id_label_lst = list(id2target.items())
if include_start_end:
id_label_lst += [(start_idx, 'start'), (end_idx, 'end')]
def split_tag_label(from_label):
from_label = from_label.lower()
if from_label in ['start', 'end']:
@@ -41,7 +41,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=False)
from_tag = from_label[:1]
from_label = from_label[2:]
return from_tag, from_label
for from_id, from_label in id_label_lst:
if from_label in ['<pad>', '<unk>']:
continue
@@ -93,7 +93,7 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label
return to_tag in ['end', 'b', 'o']
else:
raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag))
elif encoding_type == 'bmes':
"""
The first row is to_tag and the first column is from_tag; 'y' means the transition is always allowed, '-' means it is allowed only when the labels match, and 'n' means it is not allowed
@@ -151,7 +151,7 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label

class ConditionalRandomField(nn.Module):
"""
- Alias: :class:`fastNLP.modules.ConditionalRandomField` :class:`fastNLP.modules.decoder.crf.ConditionalRandomField`
+ Alias: :class:`fastNLP.modules.ConditionalRandomField` :class:`fastNLP.modules.decoder.ConditionalRandomField`

Conditional Random Field.
Provides two methods, forward() and viterbi_decode(), used for training and inference respectively.
@@ -163,21 +163,21 @@ class ConditionalRandomField(nn.Module):
obtained via the allowed_transitions() function; if None, all transitions are legal
:param str initial_method: initialization method; see initial_parameter
"""
def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None,
initial_method=None):
super(ConditionalRandomField, self).__init__()
self.include_start_end_trans = include_start_end_trans
self.num_tags = num_tags
# the meaning of entry in this matrix is (from_tag_id, to_tag_id) score
self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags))
if self.include_start_end_trans:
self.start_scores = nn.Parameter(torch.randn(num_tags))
self.end_scores = nn.Parameter(torch.randn(num_tags))
if allowed_transitions is None:
constrain = torch.zeros(num_tags + 2, num_tags + 2)
else:
@@ -185,9 +185,9 @@ class ConditionalRandomField(nn.Module):
for from_tag_id, to_tag_id in allowed_transitions:
constrain[from_tag_id, to_tag_id] = 0
self._constrain = nn.Parameter(constrain, requires_grad=False)
initial_parameter(self, initial_method)
def _normalizer_likelihood(self, logits, mask):
"""Computes the (batch_size,) denominator term for the log-likelihood, which is the
sum of the likelihoods across all possible state sequences.
@@ -200,21 +200,21 @@ class ConditionalRandomField(nn.Module):
alpha = logits[0]
if self.include_start_end_trans:
alpha = alpha + self.start_scores.view(1, -1)
flip_mask = mask.eq(0)
for i in range(1, seq_len):
emit_score = logits[i].view(batch_size, 1, n_tags)
trans_score = self.trans_m.view(1, n_tags, n_tags)
tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0)
if self.include_start_end_trans:
alpha = alpha + self.end_scores.view(1, -1)
return torch.logsumexp(alpha, 1)
def _gold_score(self, logits, tags, mask):
"""
Compute the score for the gold path.
@@ -226,7 +226,7 @@ class ConditionalRandomField(nn.Module):
seq_len, batch_size, _ = logits.size()
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
# trans_score [L-1, B]
mask = mask.byte()
flip_mask = mask.eq(0)
@@ -243,7 +243,7 @@ class ConditionalRandomField(nn.Module):
score = score + st_scores + ed_scores
# return [B,]
return score
def forward(self, feats, tags, mask):
"""
Computes the forward loss of the CRF; returns a FloatTensor of shape (batch_size,). You may need mean() to obtain the final loss.
@@ -258,9 +258,9 @@ class ConditionalRandomField(nn.Module):
mask = mask.transpose(0, 1).float()
all_path_score = self._normalizer_likelihood(feats, mask)
gold_path_score = self._gold_score(feats, tags, mask)
return all_path_score - gold_path_score
def viterbi_decode(self, logits, mask, unpad=False):
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数

@@ -277,7 +277,7 @@ class ConditionalRandomField(nn.Module):
batch_size, seq_len, n_tags = logits.size()
logits = logits.transpose(0, 1).data # L, B, H
mask = mask.transpose(0, 1).data.byte() # L, B
# dp
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
vscore = logits[0]
@@ -286,7 +286,7 @@ class ConditionalRandomField(nn.Module):
if self.include_start_end_trans:
transitions[n_tags, :n_tags] += self.start_scores.data
transitions[:n_tags, n_tags + 1] += self.end_scores.data
vscore += transitions[n_tags, :n_tags]
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
for i in range(1, seq_len):
@@ -297,17 +297,17 @@ class ConditionalRandomField(nn.Module):
vpath[i] = best_dst
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \
vscore.masked_fill(mask[i].view(batch_size, 1), 0)
if self.include_start_end_trans:
vscore += transitions[:n_tags, n_tags + 1].view(1, -1)
# backtrace
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
lens = (mask.long().sum(0) - 1)
# idxes [L, B], batched idx from seq_len-1 to 0
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len
ans = logits.new_empty((seq_len, batch_size), dtype=torch.long)
ans_score, last_tags = vscore.max(1)
ans[idxes[0], batch_idx] = last_tags
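
To see the renamed aliases in context, here is a minimal usage sketch. The tag vocabulary and tensor shapes are hypothetical, and the (paths, scores) return convention of viterbi_decode is an assumption based on the code above:

import torch
from fastNLP.modules import ConditionalRandomField, allowed_transitions

# Hypothetical BIO tag vocabulary, for illustration only.
id2target = {0: 'B-PER', 1: 'I-PER', 2: 'O'}
trans = allowed_transitions(id2target, encoding_type='bio', include_start_end=True)

crf = ConditionalRandomField(num_tags=len(id2target),
                             include_start_end_trans=True,
                             allowed_transitions=trans)

batch_size, seq_len = 2, 5
feats = torch.randn(batch_size, seq_len, len(id2target))        # emission scores
tags = torch.randint(0, len(id2target), (batch_size, seq_len))  # gold tag ids
mask = torch.ones(batch_size, seq_len)                          # 1 = real token, 0 = padding

loss = crf(feats, tags, mask).mean()             # forward() returns a (batch_size,) tensor
paths, scores = crf.viterbi_decode(feats, mask)  # assumed to return (best paths, path scores)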


fastNLP/modules/decoder/mlp.py (+5 -5)

@@ -10,7 +10,7 @@ from ..utils import initial_parameter

class MLP(nn.Module):
"""
- Alias: :class:`fastNLP.modules.MLP` :class:`fastNLP.modules.decoder.mlp.MLP`
+ Alias: :class:`fastNLP.modules.MLP` :class:`fastNLP.modules.decoder.MLP`

Multi-layer perceptron

@@ -40,7 +40,7 @@ class MLP(nn.Module):
>>> print(x)
>>> print(y)
"""
def __init__(self, size_layer, activation='relu', output_activation=None, initial_method=None, dropout=0.0):
super(MLP, self).__init__()
self.hiddens = nn.ModuleList()
@@ -51,9 +51,9 @@ class MLP(nn.Module):
self.output = nn.Linear(size_layer[i - 1], size_layer[i])
else:
self.hiddens.append(nn.Linear(size_layer[i - 1], size_layer[i]))
self.dropout = nn.Dropout(p=dropout)
actives = {
'relu': nn.ReLU(),
'tanh': nn.Tanh(),
@@ -82,7 +82,7 @@ class MLP(nn.Module):
else:
raise ValueError("should set activation correctly: {}".format(activation))
initial_parameter(self, initial_method)
def forward(self, x):
"""
:param torch.Tensor x: input to the MLP
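
Complementing the doctest above, a minimal sketch of the class under its new alias; the layer sizes and batch size are arbitrary:

import torch
from fastNLP.modules import MLP

# Arbitrary sizes: 5 input features, one hidden layer of 10 units, 5 outputs.
mlp = MLP([5, 10, 5], activation='relu', dropout=0.1)

x = torch.randn(4, 5)   # a batch of 4 examples
y = mlp(x)              # hidden layers with activation and dropout, then the output layer
print(y.size())         # torch.Size([4, 5])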


fastNLP/modules/decoder/utils.py (+5 -5)

@@ -6,7 +6,7 @@ import torch

def viterbi_decode(logits, transitions, mask=None, unpad=False):
r"""
- Alias: :class:`fastNLP.modules.viterbi_decode` :class:`fastNLP.modules.decoder.utils.viterbi_decode`
+ Alias: :class:`fastNLP.modules.viterbi_decode` :class:`fastNLP.modules.decoder.viterbi_decode`

Given a feature matrix and a transition score matrix, compute the best path and the corresponding score

@@ -30,11 +30,11 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
mask = mask.transpose(0, 1).data.byte() # L, B
else:
mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8)
# dp
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
vscore = logits[0]
trans_score = transitions.view(1, n_tags, n_tags).data
for i in range(1, seq_len):
prev_score = vscore.view(batch_size, n_tags, 1)
@@ -44,14 +44,14 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
vpath[i] = best_dst
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \
vscore.masked_fill(mask[i].view(batch_size, 1), 0)
# backtrace
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
lens = (mask.long().sum(0) - 1)
# idxes [L, B], batched idx from seq_len-1 to 0
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len
ans = logits.new_empty((seq_len, batch_size), dtype=torch.long)
ans_score, last_tags = vscore.max(1)
ans[idxes[0], batch_idx] = last_tags
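
Finally, a minimal sketch of the standalone function under its new alias; the shapes are made up, and the (paths, scores) return convention with unpad=True is an assumption based on the code above:

import torch
from fastNLP.modules import viterbi_decode

batch_size, seq_len, n_tags = 2, 6, 4
logits = torch.randn(batch_size, seq_len, n_tags)  # per-position tag scores
transitions = torch.randn(n_tags, n_tags)          # transitions[i, j]: score of tag i -> tag j
mask = torch.ones(batch_size, seq_len, dtype=torch.uint8)

paths, scores = viterbi_decode(logits, transitions, mask=mask, unpad=True)
# With unpad=True, paths is assumed to be a list of per-sample tag-id lists.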

