From 7eb02f17621af790a6afb4bf0f94a7452923206f Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Fri, 11 Jan 2019 20:59:42 +0800 Subject: [PATCH] update code comments in CRF --- fastNLP/modules/aggregator/attention.py | 6 ++++-- fastNLP/modules/decoder/CRF.py | 26 +++++++++++++------------ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/fastNLP/modules/aggregator/attention.py b/fastNLP/modules/aggregator/attention.py index a26b195d..3fea1b10 100644 --- a/fastNLP/modules/aggregator/attention.py +++ b/fastNLP/modules/aggregator/attention.py @@ -1,7 +1,9 @@ +import math + import torch -from torch import nn import torch.nn.functional as F -import math +from torch import nn + from fastNLP.modules.utils import mask_softmax diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index fe7a8465..d7db3bf9 100644 --- a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -19,13 +19,14 @@ def seq_len_to_byte_mask(seq_lens): mask = broadcast_arange.float().lt(seq_lens.float().view(-1, 1)) return mask + def allowed_transitions(id2label, encoding_type='bio'): """ - :param id2label: dict, key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 + :param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。 :param encoding_type: str, 支持"bio", "bmes"。 - :return:List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 + :return: List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). start_idx=len(id2label), end_idx=len(id2label)+1。 """ @@ -57,6 +58,7 @@ def allowed_transitions(id2label, encoding_type='bio'): allowed_trans.append((from_id, to_id)) return allowed_trans + def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): """ @@ -130,16 +132,16 @@ def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label) class ConditionalRandomField(nn.Module): - def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, initial_method=None): - """ + """ - :param num_tags: int, 标签的数量。 - :param include_start_end_trans: bool, 是否包含起始tag - :param allowed_transitions: List[Tuple[from_tag_id(int), to_tag_id(int)]]. 允许的跃迁,可以通过allowed_transitions()得到。 - 如果为None,则所有跃迁均为合法 - :param initial_method: - """ + :param int num_tags: 标签的数量。 + :param bool include_start_end_trans: 是否包含起始tag + :param list allowed_transitions: ``List[Tuple[from_tag_id(int), to_tag_id(int)]]``. 允许的跃迁,可以通过allowed_transitions()得到。 + 如果为None,则所有跃迁均为合法 + :param str initial_method: + """ + def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, initial_method=None): super(ConditionalRandomField, self).__init__() self.include_start_end_trans = include_start_end_trans @@ -235,8 +237,8 @@ class ConditionalRandomField(nn.Module): return all_path_score - gold_path_score def viterbi_decode(self, data, mask, get_score=False, unpad=False): - """ - Given a feats matrix, return best decode path and best score. + """Given a feats matrix, return best decode path and best score. + :param data:FloatTensor, batch_size x max_len x num_tags :param mask:ByteTensor batch_size x max_len :param get_score: bool, whether to output the decode score.