@@ -296,6 +296,8 @@ class AccuracyMetric(MetricBase): | |||||
def bmes_tag_to_spans(tags, ignore_labels=None): | def bmes_tag_to_spans(tags, ignore_labels=None): | ||||
""" | """ | ||||
给定一个tags的list,比如['S', 'B-singer', 'M-singer', 'E-singer', 'S', 'S']。 | |||||
返回[('', (0, 1)), ('singer', (1, 4)), ('', (4, 5)), ('', (5, 6))] (特别注意这是左闭右开区间) | |||||
:param tags: List[str], | :param tags: List[str], | ||||
:param ignore_labels: List[str], 在该list中的label将被忽略 | :param ignore_labels: List[str], 在该list中的label将被忽略 | ||||
@@ -315,13 +317,45 @@ def bmes_tag_to_spans(tags, ignore_labels=None): | |||||
else: | else: | ||||
spans.append((label, [idx, idx])) | spans.append((label, [idx, idx])) | ||||
prev_bmes_tag = bmes_tag | prev_bmes_tag = bmes_tag | ||||
return [(span[0], (span[1][0], span[1][1])) | |||||
return [(span[0], (span[1][0], span[1][1]+1)) | |||||
for span in spans | |||||
if span[0] not in ignore_labels | |||||
] | |||||
def bmeso_tag_to_spans(tags, ignore_labels=None):
    """
    Convert a BMESO-encoded tag sequence into a list of labelled spans.

    For example, ['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O'] yields
    [('singer', (1, 4))]. Every span is a half-open interval: (start, end)
    covers the positions start, start + 1, ..., end - 1.

    Each tag is lower-cased before it is parsed, so the labels in the result
    are always lower-case. 'O' tags never produce a span. An 'M' or 'E' tag
    that does not continue the preceding 'B'/'M' tag with the same label
    starts a new span of its own.

    :param tags: List[str], tags such as 'B-singer', 'M-singer', 'E-singer', 'S-singer' or 'O'.
    :param ignore_labels: List[str], spans whose label is in this list are dropped.
        The comparison is case-insensitive, matching the lower-casing applied to the tags.
    :return: List[Tuple[str, Tuple[int, int]]], i.e. [(label, (start, end)), ...] with end exclusive.
    """
    # Tags are lower-cased below, so the ignore set must be lower-cased as well;
    # otherwise an entry such as 'PER' could never match the parsed label 'per'.
    if ignore_labels:
        ignored = {
            entry.lower() if isinstance(entry, str) else entry
            for entry in ignore_labels
        }
    else:
        ignored = set()

    spans = []
    prev_bmes_tag = None
    for idx, tag in enumerate(tags):
        tag = tag.lower()
        # First character is the position marker; the label follows the separator.
        bmes_tag, label = tag[:1], tag[2:]
        if bmes_tag == 'o':
            # Outside of any entity: nothing to record.
            pass
        elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
            # Continuation of the span opened by the previous tag: move its end forward.
            spans[-1][1][1] = idx
        else:
            # 'b', 's', an 'm'/'e' that does not continue the previous span,
            # or any unrecognised marker: open a new single-token span.
            spans.append((label, [idx, idx]))
        prev_bmes_tag = bmes_tag

    # Ends are tracked inclusively while scanning; convert to half-open here.
    return [
        (label, (bounds[0], bounds[1] + 1))
        for label, bounds in spans
        if label not in ignored
    ]
def bio_tag_to_spans(tags, ignore_labels=None): | def bio_tag_to_spans(tags, ignore_labels=None): | ||||
""" | """ | ||||
给定一个tags的list,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 | |||||
返回[('singer', (1, 4))] (特别注意这是左闭右开区间) | |||||
:param tags: List[str], | :param tags: List[str], | ||||
:param ignore_labels: List[str], 在该list中的label将被忽略 | :param ignore_labels: List[str], 在该list中的label将被忽略 | ||||
@@ -343,7 +377,7 @@ def bio_tag_to_spans(tags, ignore_labels=None): | |||||
else: | else: | ||||
spans.append((label, [idx, idx])) | spans.append((label, [idx, idx])) | ||||
prev_bio_tag = bio_tag | prev_bio_tag = bio_tag | ||||
return [(span[0], (span[1][0], span[1][1])) | |||||
return [(span[0], (span[1][0], span[1][1]+1)) | |||||
for span in spans | for span in spans | ||||
if span[0] not in ignore_labels | if span[0] not in ignore_labels | ||||
] | ] | ||||
@@ -390,8 +424,7 @@ class SpanFPreRecMetric(MetricBase): | |||||
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | 则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | ||||
""" | """ | ||||
encoding_type = encoding_type.lower() | encoding_type = encoding_type.lower() | ||||
if encoding_type not in ('bio', 'bmes'): | |||||
raise ValueError("Only support 'bio' or 'bmes' type.") | |||||
if not isinstance(tag_vocab, Vocabulary): | if not isinstance(tag_vocab, Vocabulary): | ||||
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | ||||
if f_type not in ('micro', 'macro'): | if f_type not in ('micro', 'macro'): | ||||
@@ -402,6 +435,11 @@ class SpanFPreRecMetric(MetricBase): | |||||
self.tag_to_span_func = bmes_tag_to_spans | self.tag_to_span_func = bmes_tag_to_spans | ||||
elif self.encoding_type == 'bio': | elif self.encoding_type == 'bio': | ||||
self.tag_to_span_func = bio_tag_to_spans | self.tag_to_span_func = bio_tag_to_spans | ||||
elif self.encoding_type == 'bmeso': | |||||
self.tag_to_span_func = bmeso_tag_to_spans | |||||
else: | |||||
raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") | |||||
self.ignore_labels = ignore_labels | self.ignore_labels = ignore_labels | ||||
self.f_type = f_type | self.f_type = f_type | ||||
self.beta = beta | self.beta = beta | ||||
@@ -44,10 +44,14 @@ class Vocabulary(object): | |||||
:param int max_size: set the max number of words in Vocabulary. Default: None | :param int max_size: set the max number of words in Vocabulary. Default: None | ||||
:param int min_freq: set the min occur frequency of words in Vocabulary. Default: None | :param int min_freq: set the min occur frequency of words in Vocabulary. Default: None | ||||
:param padding: str, padding的字符,默认为<pad>。如果设置为None,则vocabulary中不考虑padding,为None的情况多在为label建立 | |||||
Vocabulary的情况。 | |||||
:param unknown: str, unknown的字符,默认为<unk>。如果设置为None,则vocabulary中不考虑unknown,为None的情况多在为label建立 | |||||
Vocabulary的情况。 | |||||
""" | """ | ||||
def __init__(self, max_size=None, min_freq=None, unknown='<unk>', padding='<pad>'): | |||||
def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'): | |||||
self.max_size = max_size | self.max_size = max_size | ||||
self.min_freq = min_freq | self.min_freq = min_freq | ||||
self.word_count = Counter() | self.word_count = Counter() | ||||
@@ -97,9 +101,9 @@ class Vocabulary(object): | |||||
""" | """ | ||||
self.word2idx = {} | self.word2idx = {} | ||||
if self.padding is not None: | if self.padding is not None: | ||||
self.word2idx[self.padding] = 0 | |||||
self.word2idx[self.padding] = len(self.word2idx) | |||||
if self.unknown is not None: | if self.unknown is not None: | ||||
self.word2idx[self.unknown] = 1 | |||||
self.word2idx[self.unknown] = len(self.word2idx) | |||||
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | ||||
words = self.word_count.most_common(max_size) | words = self.word_count.most_common(max_size) | ||||
@@ -877,6 +877,14 @@ class ConllPOSReader(object): | |||||
class ConllxDataLoader(object): | class ConllxDataLoader(object): | ||||
def load(self, path): | def load(self, path): | ||||
""" | |||||
:param path: str,存储数据的路径 | |||||
:return: DataSet。内含field有'words', 'pos_tags', 'heads', 'labels'(parser的label) | |||||
类似于拥有以下结构, 一行为一个instance(sample) | |||||
words pos_tags heads labels | |||||
['some', ..] ['NN', ...] [2, 3...] ['nn', 'nn'...] | |||||
""" | |||||
datalist = [] | datalist = [] | ||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
sample = [] | sample = [] | ||||
@@ -25,7 +25,7 @@ def allowed_transitions(id2label, encoding_type='bio'): | |||||
:param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | :param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | ||||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。 | "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。 | ||||
:param encoding_type: str, 支持"bio", "bmes"。 | |||||
:param encoding_type: str, 支持"bio", "bmes", "bmeso"。 | |||||
:return: List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 | :return: List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 | ||||
位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). | 位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). | ||||
start_idx=len(id2label), end_idx=len(id2label)+1。 | start_idx=len(id2label), end_idx=len(id2label)+1。 | ||||
@@ -62,7 +62,7 @@ def allowed_transitions(id2label, encoding_type='bio'): | |||||
def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | ||||
""" | """ | ||||
:param encoding_type: str, 支持"BIO", "BMES"。 | |||||
:param encoding_type: str, 支持"BIO", "BMES", "BMESO"。 | |||||
:param from_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | :param from_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | ||||
:param from_label: str, 比如"PER", "LOC"等label | :param from_label: str, 比如"PER", "LOC"等label | ||||
:param to_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | :param to_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | ||||
@@ -127,6 +127,18 @@ def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label) | |||||
return to_tag in ['b', 's', 'end'] | return to_tag in ['b', 's', 'end'] | ||||
else: | else: | ||||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag)) | raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag)) | ||||
elif encoding_type == 'bmeso': | |||||
if from_tag == 'start': | |||||
return to_tag in ['b', 's', 'o'] | |||||
elif from_tag == 'b': | |||||
return to_tag in ['m', 'e'] and from_label==to_label | |||||
elif from_tag == 'm': | |||||
return to_tag in ['m', 'e'] and from_label==to_label | |||||
elif from_tag in ['e', 's', 'o']: | |||||
return to_tag in ['b', 's', 'end', 'o'] | |||||
else: | |||||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) | |||||
else: | else: | ||||
raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type)) | raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type)) | ||||