@@ -296,6 +296,8 @@ class AccuracyMetric(MetricBase): | |||
def bmes_tag_to_spans(tags, ignore_labels=None): | |||
""" | |||
给定一个tags的list,比如['S', 'B-singer', 'M-singer', 'E-singer', 'S', 'S']。 | |||
返回[('', (0, 1)), ('singer', (1, 4)), ('', (4, 5)), ('', (5, 6))] (左闭右开区间) | |||
:param tags: List[str], | |||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||
@@ -315,13 +317,45 @@ def bmes_tag_to_spans(tags, ignore_labels=None): | |||
else: | |||
spans.append((label, [idx, idx])) | |||
prev_bmes_tag = bmes_tag | |||
return [(span[0], (span[1][0], span[1][1])) | |||
return [(span[0], (span[1][0], span[1][1]+1)) | |||
for span in spans | |||
if span[0] not in ignore_labels | |||
] | |||
def bmeso_tag_to_spans(tags, ignore_labels=None):
    """
    Convert a BMESO tag sequence into a list of labelled spans.

    For example, ['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O'] yields
    [('singer', (1, 4))]. Spans are half-open: start is inclusive, end is
    exclusive.

    Tags are lowercased before parsing, so the returned labels are always
    lowercase and ``ignore_labels`` is compared against the lowercased label.

    :param tags: List[str], each item is either 'O' or '<B|M|E|S>-<label>'.
    :param ignore_labels: List[str], spans whose label is in this list are
        dropped from the result.
    :return: List[Tuple[str, Tuple[int, int]]]. [(label, (start, end))]
    """
    ignore_labels = set(ignore_labels) if ignore_labels else set()

    spans = []
    prev_tag = None
    for idx, tag in enumerate(tags):
        tag = tag.lower()
        # First character is the position marker; the label follows the '-'.
        bmeso_tag, label = tag[:1], tag[2:]
        # An 'm'/'e' only extends the open span when it directly follows a
        # 'b'/'m' carrying the same label; spans[-1] is guaranteed to exist
        # here because every 'b'/'m' either extended or appended a span.
        continues_span = (
            bmeso_tag in ('m', 'e')
            and prev_tag in ('b', 'm')
            and label == spans[-1][0]
        )
        if continues_span:
            spans[-1][1][1] = idx
        elif bmeso_tag != 'o':
            # 'b', 's', and any ill-formed 'm'/'e' (or unknown marker) start
            # a fresh span; 'o' tokens belong to no span.
            spans.append((label, [idx, idx]))
        prev_tag = bmeso_tag

    # Internally the end index is inclusive; expose it as exclusive.
    return [
        (label, (bounds[0], bounds[1] + 1))
        for label, bounds in spans
        if label not in ignore_labels
    ]
def bio_tag_to_spans(tags, ignore_labels=None): | |||
""" | |||
给定一个tags的list,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 | |||
返回[('singer', (1, 4))] (特别注意这是左闭右开区间) | |||
:param tags: List[str], | |||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||
@@ -343,7 +377,7 @@ def bio_tag_to_spans(tags, ignore_labels=None): | |||
else: | |||
spans.append((label, [idx, idx])) | |||
prev_bio_tag = bio_tag | |||
return [(span[0], (span[1][0], span[1][1])) | |||
return [(span[0], (span[1][0], span[1][1]+1)) | |||
for span in spans | |||
if span[0] not in ignore_labels | |||
] | |||
@@ -390,8 +424,7 @@ class SpanFPreRecMetric(MetricBase): | |||
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
""" | |||
encoding_type = encoding_type.lower() | |||
if encoding_type not in ('bio', 'bmes'): | |||
raise ValueError("Only support 'bio' or 'bmes' type.") | |||
if not isinstance(tag_vocab, Vocabulary): | |||
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | |||
if f_type not in ('micro', 'macro'): | |||
@@ -402,6 +435,11 @@ class SpanFPreRecMetric(MetricBase): | |||
self.tag_to_span_func = bmes_tag_to_spans | |||
elif self.encoding_type == 'bio': | |||
self.tag_to_span_func = bio_tag_to_spans | |||
elif self.encoding_type == 'bmeso': | |||
self.tag_to_span_func = bmeso_tag_to_spans | |||
else: | |||
raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") | |||
self.ignore_labels = ignore_labels | |||
self.f_type = f_type | |||
self.beta = beta | |||
@@ -44,10 +44,14 @@ class Vocabulary(object): | |||
:param int max_size: set the max number of words in Vocabulary. Default: None | |||
:param int min_freq: set the min occur frequency of words in Vocabulary. Default: None | |||
:param padding: str, padding的字符,默认为<pad>。如果设置为None,则vocabulary中不考虑padding,为None的情况多在为label建立 | |||
Vocabulary的情况。 | |||
:param unknown: str, unknown的字符,默认为<unk>。如果设置为None,则vocabulary中不考虑unknown,为None的情况多在为label建立 | |||
Vocabulary的情况。 | |||
""" | |||
def __init__(self, max_size=None, min_freq=None, unknown='<unk>', padding='<pad>'): | |||
def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'): | |||
self.max_size = max_size | |||
self.min_freq = min_freq | |||
self.word_count = Counter() | |||
@@ -97,9 +101,9 @@ class Vocabulary(object): | |||
""" | |||
self.word2idx = {} | |||
if self.padding is not None: | |||
self.word2idx[self.padding] = 0 | |||
self.word2idx[self.padding] = len(self.word2idx) | |||
if self.unknown is not None: | |||
self.word2idx[self.unknown] = 1 | |||
self.word2idx[self.unknown] = len(self.word2idx) | |||
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | |||
words = self.word_count.most_common(max_size) | |||
@@ -877,6 +877,14 @@ class ConllPOSReader(object): | |||
class ConllxDataLoader(object): | |||
def load(self, path): | |||
""" | |||
:param path: str,存储数据的路径 | |||
:return: DataSet。内含field有'words', 'pos_tags', 'heads', 'labels'(parser的label) | |||
类似于拥有以下结构, 一行为一个instance(sample) | |||
words pos_tags heads labels | |||
['some', ..] ['NN', ...] [2, 3...] ['nn', 'nn'...] | |||
""" | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
sample = [] | |||
@@ -25,7 +25,7 @@ def allowed_transitions(id2label, encoding_type='bio'): | |||
:param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。 | |||
:param encoding_type: str, 支持"bio", "bmes"。 | |||
:param encoding_type: str, 支持"bio", "bmes", "bmeso"。 | |||
:return: List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 | |||
位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). | |||
start_idx=len(id2label), end_idx=len(id2label)+1。 | |||
@@ -62,7 +62,7 @@ def allowed_transitions(id2label, encoding_type='bio'): | |||
def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
""" | |||
:param encoding_type: str, 支持"BIO", "BMES"。 | |||
:param encoding_type: str, 支持"BIO", "BMES", "BMESO"。 | |||
:param from_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||
:param from_label: str, 比如"PER", "LOC"等label | |||
:param to_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||
@@ -127,6 +127,18 @@ def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label) | |||
return to_tag in ['b', 's', 'end'] | |||
else: | |||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag)) | |||
elif encoding_type == 'bmeso': | |||
if from_tag == 'start': | |||
return to_tag in ['b', 's', 'o'] | |||
elif from_tag == 'b': | |||
return to_tag in ['m', 'e'] and from_label==to_label | |||
elif from_tag == 'm': | |||
return to_tag in ['m', 'e'] and from_label==to_label | |||
elif from_tag in ['e', 's', 'o']: | |||
return to_tag in ['b', 's', 'end', 'o'] | |||
else: | |||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) | |||
else: | |||
raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type)) | |||