Browse Source

1. CRF增加支持bmeso类型的tag 2. vocabulary中增加注释

tags/v0.4.10
yh 5 years ago
parent
commit
29eab18b78
4 changed files with 71 additions and 9 deletions
  1. +42
    -4
      fastNLP/core/metrics.py
  2. +7
    -3
      fastNLP/core/vocabulary.py
  3. +8
    -0
      fastNLP/io/dataset_loader.py
  4. +14
    -2
      fastNLP/modules/decoder/CRF.py

+ 42
- 4
fastNLP/core/metrics.py View File

@@ -296,6 +296,8 @@ class AccuracyMetric(MetricBase):


def bmes_tag_to_spans(tags, ignore_labels=None): def bmes_tag_to_spans(tags, ignore_labels=None):
""" """
给定一个tags的list,比如['S', 'B-singer', 'M-singer', 'E-singer', 'S', 'S']。
返回[('', (0, 1)), ('singer', (1, 4)), ('', (4, 5)), ('', (5, 6))] (特别注意这是左闭右开区间)


:param tags: List[str], :param tags: List[str],
:param ignore_labels: List[str], 在该list中的label将被忽略 :param ignore_labels: List[str], 在该list中的label将被忽略
@@ -315,13 +317,45 @@ def bmes_tag_to_spans(tags, ignore_labels=None):
else: else:
spans.append((label, [idx, idx])) spans.append((label, [idx, idx]))
prev_bmes_tag = bmes_tag prev_bmes_tag = bmes_tag
return [(span[0], (span[1][0], span[1][1]))
return [(span[0], (span[1][0], span[1][1]+1))
for span in spans
if span[0] not in ignore_labels
]

def bmeso_tag_to_spans(tags, ignore_labels=None):
    """
    Convert a BMESO-encoded tag sequence into a list of labelled spans.

    For example, ['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O'] yields
    [('singer', (1, 4))]. Ranges are half-open, i.e. [start, end).

    Tags are matched case-insensitively, so the labels in the result are
    always lower-cased.

    :param tags: List[str], each tag is either "O" or "<B|M|E|S>-<label>".
    :param ignore_labels: List[str], spans whose label is in this list are
        dropped from the result (compared case-insensitively).
    :return: List[Tuple[str, Tuple[int, int]]], i.e. [(label, (start, end))].
    """
    # Labels are lower-cased below, so the ignore set must be lower-cased too;
    # otherwise an entry such as 'PER' could never match the extracted 'per'.
    ignore_labels = {label.lower() for label in ignore_labels} if ignore_labels else set()

    spans = []
    prev_bmes_tag = None
    for idx, tag in enumerate(tags):
        tag = tag.lower()
        # "b-singer" -> ('b', 'singer'); a bare "o" -> ('o', '').
        bmes_tag, label = tag[:1], tag[2:]
        if bmes_tag in ('b', 's'):
            # B and S always open a new span.
            spans.append((label, [idx, idx]))
        elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
            # M/E continues the currently open span only when it directly
            # follows B/M carrying the same label.
            spans[-1][1][1] = idx
        elif bmes_tag == 'o':
            # Outside any entity: contributes no span.
            pass
        else:
            # A dangling M/E (or an unrecognised prefix) is kept as its own span.
            spans.append((label, [idx, idx]))
        prev_bmes_tag = bmes_tag
    # The end index is stored inclusively while scanning; +1 makes it exclusive.
    return [(span[0], (span[1][0], span[1][1] + 1))
            for span in spans
            if span[0] not in ignore_labels
            ]


def bio_tag_to_spans(tags, ignore_labels=None): def bio_tag_to_spans(tags, ignore_labels=None):
""" """
给定一个tags的list,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。
返回[('singer', (1, 4))] (特别注意这是左闭右开区间)


:param tags: List[str], :param tags: List[str],
:param ignore_labels: List[str], 在该list中的label将被忽略 :param ignore_labels: List[str], 在该list中的label将被忽略
@@ -343,7 +377,7 @@ def bio_tag_to_spans(tags, ignore_labels=None):
else: else:
spans.append((label, [idx, idx])) spans.append((label, [idx, idx]))
prev_bio_tag = bio_tag prev_bio_tag = bio_tag
return [(span[0], (span[1][0], span[1][1]))
return [(span[0], (span[1][0], span[1][1]+1))
for span in spans for span in spans
if span[0] not in ignore_labels if span[0] not in ignore_labels
] ]
@@ -390,8 +424,7 @@ class SpanFPreRecMetric(MetricBase):
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。
""" """
encoding_type = encoding_type.lower() encoding_type = encoding_type.lower()
if encoding_type not in ('bio', 'bmes'):
raise ValueError("Only support 'bio' or 'bmes' type.")

if not isinstance(tag_vocab, Vocabulary): if not isinstance(tag_vocab, Vocabulary):
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
if f_type not in ('micro', 'macro'): if f_type not in ('micro', 'macro'):
@@ -402,6 +435,11 @@ class SpanFPreRecMetric(MetricBase):
self.tag_to_span_func = bmes_tag_to_spans self.tag_to_span_func = bmes_tag_to_spans
elif self.encoding_type == 'bio': elif self.encoding_type == 'bio':
self.tag_to_span_func = bio_tag_to_spans self.tag_to_span_func = bio_tag_to_spans
elif self.encoding_type == 'bmeso':
self.tag_to_span_func = bmeso_tag_to_spans
else:
raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.")

self.ignore_labels = ignore_labels self.ignore_labels = ignore_labels
self.f_type = f_type self.f_type = f_type
self.beta = beta self.beta = beta


+ 7
- 3
fastNLP/core/vocabulary.py View File

@@ -44,10 +44,14 @@ class Vocabulary(object):


:param int max_size: set the max number of words in Vocabulary. Default: None :param int max_size: set the max number of words in Vocabulary. Default: None
:param int min_freq: set the min occur frequency of words in Vocabulary. Default: None :param int min_freq: set the min occur frequency of words in Vocabulary. Default: None
:param padding: str, padding的字符,默认为<pad>。如果设置为None,则vocabulary中不考虑padding,为None的情况多在为label建立
Vocabulary的情况。
:param unknown: str, unknown的字符,默认为<unk>。如果设置为None,则vocabulary中不考虑unknown,为None的情况多在为label建立
Vocabulary的情况。


""" """


def __init__(self, max_size=None, min_freq=None, unknown='<unk>', padding='<pad>'):
def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'):
self.max_size = max_size self.max_size = max_size
self.min_freq = min_freq self.min_freq = min_freq
self.word_count = Counter() self.word_count = Counter()
@@ -97,9 +101,9 @@ class Vocabulary(object):
""" """
self.word2idx = {} self.word2idx = {}
if self.padding is not None: if self.padding is not None:
self.word2idx[self.padding] = 0
self.word2idx[self.padding] = len(self.word2idx)
if self.unknown is not None: if self.unknown is not None:
self.word2idx[self.unknown] = 1
self.word2idx[self.unknown] = len(self.word2idx)


max_size = min(self.max_size, len(self.word_count)) if self.max_size else None max_size = min(self.max_size, len(self.word_count)) if self.max_size else None
words = self.word_count.most_common(max_size) words = self.word_count.most_common(max_size)


+ 8
- 0
fastNLP/io/dataset_loader.py View File

@@ -877,6 +877,14 @@ class ConllPOSReader(object):


class ConllxDataLoader(object): class ConllxDataLoader(object):
def load(self, path): def load(self, path):
"""

:param path: str,存储数据的路径
:return: DataSet。内含field有'words', 'pos_tags', 'heads', 'labels'(parser的label)
类似于拥有以下结构, 一行为一个instance(sample)
words pos_tags heads labels
['some', ..] ['NN', ...] [2, 3...] ['nn', 'nn'...]
"""
datalist = [] datalist = []
with open(path, 'r', encoding='utf-8') as f: with open(path, 'r', encoding='utf-8') as f:
sample = [] sample = []


+ 14
- 2
fastNLP/modules/decoder/CRF.py View File

@@ -25,7 +25,7 @@ def allowed_transitions(id2label, encoding_type='bio'):


:param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 :param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。 "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。
:param encoding_type: str, 支持"bio", "bmes"。
:param encoding_type: str, 支持"bio", "bmes", "bmeso"。
:return: List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 :return: List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以
位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). 位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx).
start_idx=len(id2label), end_idx=len(id2label)+1。 start_idx=len(id2label), end_idx=len(id2label)+1。
@@ -62,7 +62,7 @@ def allowed_transitions(id2label, encoding_type='bio'):
def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
""" """


:param encoding_type: str, 支持"BIO", "BMES"。
:param encoding_type: str, 支持"BIO", "BMES", "BMESO"。
:param from_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag :param from_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag
:param from_label: str, 比如"PER", "LOC"等label :param from_label: str, 比如"PER", "LOC"等label
:param to_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag :param to_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag
@@ -127,6 +127,18 @@ def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label)
return to_tag in ['b', 's', 'end'] return to_tag in ['b', 's', 'end']
else: else:
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag)) raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag))
elif encoding_type == 'bmeso':
if from_tag == 'start':
return to_tag in ['b', 's', 'o']
elif from_tag == 'b':
return to_tag in ['m', 'e'] and from_label==to_label
elif from_tag == 'm':
return to_tag in ['m', 'e'] and from_label==to_label
elif from_tag in ['e', 's', 'o']:
return to_tag in ['b', 's', 'end', 'o']
else:
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag))

else: else:
raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type)) raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type))




Loading…
Cancel
Save