Browse Source

update code comments in CRF

tags/v0.3.0^2
FengZiYjun 5 years ago
parent
commit
7eb02f1762
2 changed files with 18 additions and 14 deletions
  1. +4
    -2
      fastNLP/modules/aggregator/attention.py
  2. +14
    -12
      fastNLP/modules/decoder/CRF.py

+ 4
- 2
fastNLP/modules/aggregator/attention.py View File

@@ -1,7 +1,9 @@
import math

import torch import torch
from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
import math
from torch import nn

from fastNLP.modules.utils import mask_softmax from fastNLP.modules.utils import mask_softmax






+ 14
- 12
fastNLP/modules/decoder/CRF.py View File

@@ -19,13 +19,14 @@ def seq_len_to_byte_mask(seq_lens):
mask = broadcast_arange.float().lt(seq_lens.float().view(-1, 1)) mask = broadcast_arange.float().lt(seq_lens.float().view(-1, 1))
return mask return mask



def allowed_transitions(id2label, encoding_type='bio'): def allowed_transitions(id2label, encoding_type='bio'):
""" """


:param id2label: dict, key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是
:param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。 "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。
:param encoding_type: str, 支持"bio", "bmes"。 :param encoding_type: str, 支持"bio", "bmes"。
:return:List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以
:return: List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以
位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). 位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx).
start_idx=len(id2label), end_idx=len(id2label)+1。 start_idx=len(id2label), end_idx=len(id2label)+1。
""" """
@@ -57,6 +58,7 @@ def allowed_transitions(id2label, encoding_type='bio'):
allowed_trans.append((from_id, to_id)) allowed_trans.append((from_id, to_id))
return allowed_trans return allowed_trans



def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
""" """


@@ -130,16 +132,16 @@ def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label)




class ConditionalRandomField(nn.Module): class ConditionalRandomField(nn.Module):
def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, initial_method=None):
"""
"""


:param num_tags: int, 标签的数量。
:param include_start_end_trans: bool, 是否包含起始tag
:param allowed_transitions: List[Tuple[from_tag_id(int), to_tag_id(int)]]. 允许的跃迁,可以通过allowed_transitions()得到。
如果为None,则所有跃迁均为合法
:param initial_method:
"""
:param int num_tags: 标签的数量。
:param bool include_start_end_trans: 是否包含起始tag
:param list allowed_transitions: ``List[Tuple[from_tag_id(int), to_tag_id(int)]]``. 允许的跃迁,可以通过allowed_transitions()得到。
如果为None,则所有跃迁均为合法
:param str initial_method:
"""


def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, initial_method=None):
super(ConditionalRandomField, self).__init__() super(ConditionalRandomField, self).__init__()


self.include_start_end_trans = include_start_end_trans self.include_start_end_trans = include_start_end_trans
@@ -235,8 +237,8 @@ class ConditionalRandomField(nn.Module):
return all_path_score - gold_path_score return all_path_score - gold_path_score


def viterbi_decode(self, data, mask, get_score=False, unpad=False): def viterbi_decode(self, data, mask, get_score=False, unpad=False):
"""
Given a feats matrix, return best decode path and best score.
"""Given a feats matrix, return best decode path and best score.
:param data:FloatTensor, batch_size x max_len x num_tags :param data:FloatTensor, batch_size x max_len x num_tags
:param mask:ByteTensor batch_size x max_len :param mask:ByteTensor batch_size x max_len
:param get_score: bool, whether to output the decode score. :param get_score: bool, whether to output the decode score.


Loading…
Cancel
Save