From d6ae241bbb51df3c8636331f4fc4607741cd3dd7 Mon Sep 17 00:00:00 2001 From: ChenXin Date: Fri, 12 Jul 2019 10:49:18 +0800 Subject: [PATCH] =?UTF-8?q?decoder=E9=83=A8=E5=88=86=E7=9A=84=E5=88=AB?= =?UTF-8?q?=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/modules/decoder/crf.py | 52 ++++++++++++++++---------------- fastNLP/modules/decoder/mlp.py | 10 +++--- fastNLP/modules/decoder/utils.py | 10 +++--- 3 files changed, 36 insertions(+), 36 deletions(-) diff --git a/fastNLP/modules/decoder/crf.py b/fastNLP/modules/decoder/crf.py index c0717d6f..7c496868 100644 --- a/fastNLP/modules/decoder/crf.py +++ b/fastNLP/modules/decoder/crf.py @@ -11,7 +11,7 @@ from ..utils import initial_parameter def allowed_transitions(id2target, encoding_type='bio', include_start_end=False): """ - 别名::class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.crf.allowed_transitions` + 别名::class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.allowed_transitions` 给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 @@ -31,7 +31,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=False) id_label_lst = list(id2target.items()) if include_start_end: id_label_lst += [(start_idx, 'start'), (end_idx, 'end')] - + def split_tag_label(from_label): from_label = from_label.lower() if from_label in ['start', 'end']: @@ -41,7 +41,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=False) from_tag = from_label[:1] from_label = from_label[2:] return from_tag, from_label - + for from_id, from_label in id_label_lst: if from_label in ['', '']: continue @@ -93,7 +93,7 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label return to_tag in ['end', 'b', 'o'] else: raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag)) - + elif encoding_type == 'bmes': """ 第一行是to_tag, 第一列是from_tag,y任意条件下可转,-只有在label相同时可转,n不可转 @@ -151,7 +151,7 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label class ConditionalRandomField(nn.Module): """ - 别名::class:`fastNLP.modules.ConditionalRandomField` :class:`fastNLP.modules.decoder.crf.ConditionalRandomField` + 别名::class:`fastNLP.modules.ConditionalRandomField` :class:`fastNLP.modules.decoder.ConditionalRandomField` 条件随机场。 提供forward()以及viterbi_decode()两个方法,分别用于训练与inference。 @@ -163,21 +163,21 @@ class ConditionalRandomField(nn.Module): allowed_transitions()函数得到;如果为None,则所有跃迁均为合法 :param str initial_method: 初始化方法。见initial_parameter """ - + def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, initial_method=None): - + super(ConditionalRandomField, self).__init__() - + self.include_start_end_trans = include_start_end_trans self.num_tags = num_tags - + # the meaning of entry in this matrix is (from_tag_id, to_tag_id) score self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags)) if self.include_start_end_trans: self.start_scores = nn.Parameter(torch.randn(num_tags)) self.end_scores = nn.Parameter(torch.randn(num_tags)) - + if allowed_transitions is None: constrain = torch.zeros(num_tags + 2, num_tags + 2) else: @@ -185,9 +185,9 @@ class ConditionalRandomField(nn.Module): for from_tag_id, to_tag_id in allowed_transitions: constrain[from_tag_id, to_tag_id] = 0 self._constrain = nn.Parameter(constrain, requires_grad=False) - + initial_parameter(self, initial_method) - + def _normalizer_likelihood(self, logits, mask): """Computes the (batch_size,) denominator term for the log-likelihood, which is the sum of the likelihoods across all possible state sequences. @@ -200,21 +200,21 @@ class ConditionalRandomField(nn.Module): alpha = logits[0] if self.include_start_end_trans: alpha = alpha + self.start_scores.view(1, -1) - + flip_mask = mask.eq(0) - + for i in range(1, seq_len): emit_score = logits[i].view(batch_size, 1, n_tags) trans_score = self.trans_m.view(1, n_tags, n_tags) tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0) - + if self.include_start_end_trans: alpha = alpha + self.end_scores.view(1, -1) - + return torch.logsumexp(alpha, 1) - + def _gold_score(self, logits, tags, mask): """ Compute the score for the gold path. @@ -226,7 +226,7 @@ class ConditionalRandomField(nn.Module): seq_len, batch_size, _ = logits.size() batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) - + # trans_socre [L-1, B] mask = mask.byte() flip_mask = mask.eq(0) @@ -243,7 +243,7 @@ class ConditionalRandomField(nn.Module): score = score + st_scores + ed_scores # return [B,] return score - + def forward(self, feats, tags, mask): """ 用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。 @@ -258,9 +258,9 @@ class ConditionalRandomField(nn.Module): mask = mask.transpose(0, 1).float() all_path_score = self._normalizer_likelihood(feats, mask) gold_path_score = self._gold_score(feats, tags, mask) - + return all_path_score - gold_path_score - + def viterbi_decode(self, logits, mask, unpad=False): """给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 @@ -277,7 +277,7 @@ class ConditionalRandomField(nn.Module): batch_size, seq_len, n_tags = logits.size() logits = logits.transpose(0, 1).data # L, B, H mask = mask.transpose(0, 1).data.byte() # L, B - + # dp vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) vscore = logits[0] @@ -286,7 +286,7 @@ class ConditionalRandomField(nn.Module): if self.include_start_end_trans: transitions[n_tags, :n_tags] += self.start_scores.data transitions[:n_tags, n_tags + 1] += self.end_scores.data - + vscore += transitions[n_tags, :n_tags] trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data for i in range(1, seq_len): @@ -297,17 +297,17 @@ class ConditionalRandomField(nn.Module): vpath[i] = best_dst vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ vscore.masked_fill(mask[i].view(batch_size, 1), 0) - + if self.include_start_end_trans: vscore += transitions[:n_tags, n_tags + 1].view(1, -1) - + # backtrace batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) lens = (mask.long().sum(0) - 1) # idxes [L, B], batched idx from seq_len-1 to 0 idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len - + ans = logits.new_empty((seq_len, batch_size), dtype=torch.long) ans_score, last_tags = vscore.max(1) ans[idxes[0], batch_idx] = last_tags diff --git a/fastNLP/modules/decoder/mlp.py b/fastNLP/modules/decoder/mlp.py index 418b3a77..9d9d80f2 100644 --- a/fastNLP/modules/decoder/mlp.py +++ b/fastNLP/modules/decoder/mlp.py @@ -10,7 +10,7 @@ from ..utils import initial_parameter class MLP(nn.Module): """ - 别名::class:`fastNLP.modules.MLP` :class:`fastNLP.modules.decoder.mlp.MLP` + 别名::class:`fastNLP.modules.MLP` :class:`fastNLP.modules.decoder.MLP` 多层感知器 @@ -40,7 +40,7 @@ class MLP(nn.Module): >>> print(x) >>> print(y) """ - + def __init__(self, size_layer, activation='relu', output_activation=None, initial_method=None, dropout=0.0): super(MLP, self).__init__() self.hiddens = nn.ModuleList() @@ -51,9 +51,9 @@ class MLP(nn.Module): self.output = nn.Linear(size_layer[i - 1], size_layer[i]) else: self.hiddens.append(nn.Linear(size_layer[i - 1], size_layer[i])) - + self.dropout = nn.Dropout(p=dropout) - + actives = { 'relu': nn.ReLU(), 'tanh': nn.Tanh(), @@ -82,7 +82,7 @@ class MLP(nn.Module): else: raise ValueError("should set activation correctly: {}".format(activation)) initial_parameter(self, initial_method) - + def forward(self, x): """ :param torch.Tensor x: MLP接受的输入 diff --git a/fastNLP/modules/decoder/utils.py b/fastNLP/modules/decoder/utils.py index 249f3ff6..9e773336 100644 --- a/fastNLP/modules/decoder/utils.py +++ b/fastNLP/modules/decoder/utils.py @@ -6,7 +6,7 @@ import torch def viterbi_decode(logits, transitions, mask=None, unpad=False): r""" - 别名::class:`fastNLP.modules.viterbi_decode` :class:`fastNLP.modules.decoder.utils.viterbi_decode` + 别名::class:`fastNLP.modules.viterbi_decode` :class:`fastNLP.modules.decoder.viterbi_decode` 给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 @@ -30,11 +30,11 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False): mask = mask.transpose(0, 1).data.byte() # L, B else: mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8) - + # dp vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) vscore = logits[0] - + trans_score = transitions.view(1, n_tags, n_tags).data for i in range(1, seq_len): prev_score = vscore.view(batch_size, n_tags, 1) @@ -44,14 +44,14 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False): vpath[i] = best_dst vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ vscore.masked_fill(mask[i].view(batch_size, 1), 0) - + # backtrace batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) lens = (mask.long().sum(0) - 1) # idxes [L, B], batched idx from seq_len-1 to 0 idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len - + ans = logits.new_empty((seq_len, batch_size), dtype=torch.long) ans_score, last_tags = vscore.max(1) ans[idxes[0], batch_idx] = last_tags