From 3624f7dafddc23bfd60faeaeb4e8eefe541fa2eb Mon Sep 17 00:00:00 2001
From: yh
Date: Mon, 19 Aug 2019 23:35:47 +0800
Subject: [PATCH] Add Conll2003Pipe
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/callback.py        |  2 +-
 fastNLP/core/metrics.py         |  1 +
 fastNLP/io/pipe/conll.py        | 92 ++++++++++++++++++-
 fastNLP/io/pipe/utils.py        | 38 ++++----
 .../chinese_ner/train_cn_ner.py |  2 +-
 5 files changed, 113 insertions(+), 22 deletions(-)

diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index 53767011..47d4174b 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -646,7 +646,7 @@ class EvaluateCallback(Callback):
             raise TypeError("data receives dict[DataSet] or DataSet object.")
 
     def on_train_begin(self):
-        if len(self.datasets) > 0and self.trainer.dev_data is None:
+        if len(self.datasets) > 0 and self.trainer.dev_data is None:
             raise RuntimeError("Trainer has no dev data, you cannot pass extra DataSet to do evaluation.")
 
         if len(self.datasets) > 0:
diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index 8dd51eb6..ef6f8b69 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -358,6 +358,7 @@ def _bmes_tag_to_spans(tags, ignore_labels=None):
     """
     Given a list of tags, e.g. ['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-moive', 'S-actor'],
     returns [('song', (0, 1)), ('singer', (1, 4)), ('moive', (4, 5)), ('actor', (5, 6))] (half-open intervals).
+    The tags may also be a plain ['S', 'B', 'M', 'E', 'B', 'M', 'M', ...] sequence without labels.
 
     :param tags: List[str],
     :param ignore_labels: List[str], labels in this list will be ignored
diff --git a/fastNLP/io/pipe/conll.py b/fastNLP/io/pipe/conll.py
index fb599340..58fab281 100644
--- a/fastNLP/io/pipe/conll.py
+++ b/fastNLP/io/pipe/conll.py
@@ -5,8 +5,8 @@ from ...core.const import Const
 from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader
 from .utils import _indexize, _add_words_field
 from .utils import _add_chars_field
-from ..loader.conll import PeopleDailyNERLoader, WeiboNERLoader, MsraNERLoader
-
+from ..loader.conll import PeopleDailyNERLoader, WeiboNERLoader, MsraNERLoader, ConllLoader
+from ...core.vocabulary import Vocabulary
 
 class _NERPipe(Pipe):
     """
@@ -78,7 +78,7 @@ class Conll2003NERPipe(_NERPipe):
        :header: "raw_words", "words", "target", "seq_len"
 
        "[Nadim, Ladki]", "[2, 3]", "[1, 2]", 2
-       "[AL-AIN, United, Arab, ...]", "[4, 5, 6,...]", "[3, 4]", 6
+       "[AL-AIN, United, Arab, ...]", "[4, 5, 6,...]", "[3, 4,...]", 6
        "[...]", "[...]", "[...]", .
 
     raw_words is a List[str] column holding the untransformed raw data; words is a List[int] column holding the input converted to indices; target is a List[int] column holding the index-converted
@@ -102,6 +102,90 @@ class Conll2003NERPipe(_NERPipe):
         return data_bundle
 
 
+class Conll2003Pipe(Pipe):
+    def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False, target_pad_val=0):
+        """
+        After this Pipe has run, the DataSet contains the following fields
+
+        .. csv-table::
+           :header: "raw_words", "words", "pos", "chunk", "ner", "seq_len"
+
+           "[Nadim, Ladki]", "[2, 3]", "[0, 0]", "[1, 2]", "[1, 2]", 2
+           "[AL-AIN, United, Arab, ...]", "[4, 5, 6,...]", "[1, 2...]", "[3, 4...]", "[3, 4...]", 6
+           "[...]", "[...]", "[...]", "[...]", "[...]", "."
+
+        Here words and seq_len are input fields; pos, chunk, ner and seq_len are target fields.
+
+        :param str chunk_encoding_type: supports bioes and bio.
+        :param str ner_encoding_type: supports bioes and bio.
+        :param bool lower: whether to lowercase the words column before building the vocabulary
+        :param int target_pad_val: padding value for the pos, ner and chunk columns
+        """
+        if chunk_encoding_type == 'bio':
+            self.chunk_convert_tag = iob2
+        else:
+            self.chunk_convert_tag = lambda tags: iob2bioes(iob2(tags))
+        if ner_encoding_type == 'bio':
+            self.ner_convert_tag = iob2
+        else:
+            self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags))
+        self.lower = lower
+        self.target_pad_val = int(target_pad_val)
+
+    def process(self, data_bundle)->DataBundle:
+        """
+        The input DataSet should look like the following
+
+        .. csv-table::
+           :header: "raw_words", "pos", "chunk", "ner"
+
+           "[Nadim, Ladki]", "[NNP, NNP]", "[B-NP, I-NP]", "[B-PER, I-PER]"
+           "[AL-AIN, United, Arab, ...]", "[NNP, NNP...]", "[B-NP, B-NP, ...]", "[B-LOC, B-LOC,...]"
+           "[...]", "[...]", "[...]", "[...]"
+
+        :param data_bundle:
+        :return: the DataBundle that was passed in
+        """
+        # convert the chunk/ner tags to the requested encoding
+        for name, dataset in data_bundle.datasets.items():
+            dataset.drop(lambda x: "-DOCSTART-" in x[Const.RAW_WORD])
+            dataset.apply_field(self.chunk_convert_tag, field_name='chunk', new_field_name='chunk')
+            dataset.apply_field(self.ner_convert_tag, field_name='ner', new_field_name='ner')
+
+        _add_words_field(data_bundle, lower=self.lower)
+
+        # index
+        _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=['pos', 'ner'])
+        # some chunk tags appear only in dev, not in train, so build the chunk vocabulary over all splits
+        tgt_vocab = Vocabulary(unknown=None, padding=None)
+        tgt_vocab.from_dataset(*data_bundle.datasets.values(), field_name='chunk')
+        tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name='chunk')
+        data_bundle.set_vocab(tgt_vocab, 'chunk')
+
+        input_fields = [Const.INPUT, Const.INPUT_LEN]
+        target_fields = ['pos', 'ner', 'chunk', Const.INPUT_LEN]
+
+        for name, dataset in data_bundle.datasets.items():
+            dataset.set_pad_val('pos', self.target_pad_val)
+            dataset.set_pad_val('ner', self.target_pad_val)
+            dataset.set_pad_val('chunk', self.target_pad_val)
+            dataset.add_seq_len(Const.INPUT)
+
+        data_bundle.set_input(*input_fields)
+        data_bundle.set_target(*target_fields)
+
+        return data_bundle
+
+    def process_from_file(self, paths):
+        """
+
+        :param paths:
+        :return:
+        """
+        data_bundle = ConllLoader(headers=['raw_words', 'pos', 'chunk', 'ner']).load(paths)
+        return self.process(data_bundle)
+
+
 class OntoNotesNERPipe(_NERPipe):
     """
     Processes OntoNotes NER data; after processing, the fields in the DataSet are
@@ -171,7 +255,7 @@ class _CNNERPipe(Pipe):
         _add_chars_field(data_bundle, lower=False)
 
         # index
-        _indexize(data_bundle, input_field_name=Const.CHAR_INPUT, target_field_name=Const.TARGET)
+        _indexize(data_bundle, input_field_names=Const.CHAR_INPUT, target_field_names=Const.TARGET)
 
         input_fields = [Const.TARGET, Const.CHAR_INPUT, Const.INPUT_LEN]
         target_fields = [Const.TARGET, Const.INPUT_LEN]
diff --git a/fastNLP/io/pipe/utils.py b/fastNLP/io/pipe/utils.py
index 7d011446..8facd8d9 100644
--- a/fastNLP/io/pipe/utils.py
+++ b/fastNLP/io/pipe/utils.py
@@ -4,7 +4,8 @@ from ...core.const import Const
 
 def iob2(tags:List[str])->List[str]:
     """
-    Check whether the tags are valid IOB data; IOB1 tags are automatically converted to IOB2. For the difference between the two formats see https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format
+    Check whether the tags are valid IOB data; IOB1 tags are automatically converted to IOB2. For the difference between the two formats see
+    https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format
 
     :param tags: the tags to convert
     """
@@ -76,27 +77,32 @@ def _raw_split(sent):
     return sent.split()
 
 
-def _indexize(data_bundle, input_field_name=Const.INPUT, target_field_name=Const.TARGET):
+def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Const.TARGET):
     """
     Build a vocabulary over the field_name column of the datasets and over the Const.TARGET column, and add the vocabularies to the data_bundle.
 
     :param data_bundle:
-    :param: str input_field_name:
-    :param: str target_field_name: the vocabulary for this column has no unknown and no padding token
+    :param: str,list input_field_names:
+    :param: str,list target_field_names: the vocabularies for these columns have no unknown and no padding token
     :return:
     """
-    src_vocab = Vocabulary()
-    src_vocab.from_dataset(data_bundle.datasets['train'], field_name=input_field_name,
-                           no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
-                                                    name != 'train'])
-    src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=input_field_name)
-
-    tgt_vocab = Vocabulary(unknown=None, padding=None)
-    tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=target_field_name)
-    tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name=target_field_name)
-
-    data_bundle.set_vocab(src_vocab, input_field_name)
-    data_bundle.set_vocab(tgt_vocab, target_field_name)
+    if isinstance(input_field_names, str):
+        input_field_names = [input_field_names]
+    if isinstance(target_field_names, str):
+        target_field_names = [target_field_names]
+    for input_field_name in input_field_names:
+        src_vocab = Vocabulary()
+        src_vocab.from_dataset(data_bundle.datasets['train'], field_name=input_field_name,
+                               no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
+                                                        name != 'train'])
+        src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=input_field_name)
+        data_bundle.set_vocab(src_vocab, input_field_name)
+
+    for target_field_name in target_field_names:
+        tgt_vocab = Vocabulary(unknown=None, padding=None)
+        tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=target_field_name)
+        tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name=target_field_name)
+        data_bundle.set_vocab(tgt_vocab, target_field_name)
 
     return data_bundle
 
diff --git a/reproduction/seqence_labelling/chinese_ner/train_cn_ner.py b/reproduction/seqence_labelling/chinese_ner/train_cn_ner.py
index 1005ea23..58b32265 100644
--- a/reproduction/seqence_labelling/chinese_ner/train_cn_ner.py
+++ b/reproduction/seqence_labelling/chinese_ner/train_cn_ner.py
@@ -47,7 +47,7 @@ class ChineseNERPipe(Pipe):
         _add_chars_field(data_bundle, lower=False)
 
         # index
-        _indexize(data_bundle, input_field_name=C.CHAR_INPUT, target_field_name=C.TARGET)
+        _indexize(data_bundle, input_field_names=C.CHAR_INPUT, target_field_names=C.TARGET)
 
         for name, dataset in data_bundle.datasets.items():
             dataset.set_pad_val(C.TARGET, self.target_pad_val)
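
A minimal usage sketch of the Conll2003Pipe added by this patch, assuming a local copy of the CoNLL-2003 splits; the file paths below are placeholders, everything else uses only the classes and methods touched by this change.

from fastNLP.io.pipe.conll import Conll2003Pipe

# Placeholder paths to a local CoNLL-2003 copy (columns: word POS chunk NER);
# adjust them to your own layout.
paths = {'train': 'data/conll2003/train.txt',
         'dev': 'data/conll2003/dev.txt',
         'test': 'data/conll2003/test.txt'}

# Convert chunk/ner tags to BIOES, build vocabularies and index words, pos, chunk and ner.
pipe = Conll2003Pipe(chunk_encoding_type='bioes', ner_encoding_type='bioes', lower=False)
data_bundle = pipe.process_from_file(paths)

print(data_bundle.datasets['train'][:2])   # words, pos, chunk, ner are already index-encoded
print(len(data_bundle.vocabs['ner']))      # number of NER tags after BIOES conversion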