|
|
@@ -7,7 +7,7 @@ import torch |
|
|
|
from torch import nn |
|
|
|
|
|
|
|
from ..utils import initial_parameter |
|
|
|
|
|
|
|
from ...core import Vocabulary |
|
|
|
|
|
|
|
def allowed_transitions(id2target, encoding_type='bio', include_start_end=False): |
|
|
|
""" |
|
|
@@ -15,7 +15,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=False) |
|
|
|
|
|
|
|
给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 |
|
|
|
|
|
|
|
:param dict id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 |
|
|
|
:param dict,Vocabulary id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 |
|
|
|
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.idx2word得到id2label。 |
|
|
|
:param str encoding_type: 支持"bio", "bmes", "bmeso", "bioes"。 |
|
|
|
:param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; |
|
|
@@ -23,6 +23,8 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=False) |
|
|
|
start_idx=len(id2label), end_idx=len(id2label)+1。为False, 返回的结果中不含与开始结尾相关的内容 |
|
|
|
:return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。 |
|
|
|
""" |
|
|
|
if isinstance(id2target, Vocabulary): |
|
|
|
id2target = id2target.idx2word |
|
|
|
num_tags = len(id2target) |
|
|
|
start_idx = num_tags |
|
|
|
end_idx = num_tags + 1 |
|
|
|