diff --git a/fastNLP/io/pipe/cws.py b/fastNLP/io/pipe/cws.py index 6ea1ae0c..4ca0219c 100644 --- a/fastNLP/io/pipe/cws.py +++ b/fastNLP/io/pipe/cws.py @@ -1,10 +1,13 @@ +import re +from itertools import chain + from .pipe import Pipe +from .utils import _indexize from .. import DataBundle from ..loader import CWSLoader -from ... import Const -from itertools import chain -from .utils import _indexize -import re +from ...core.const import Const + + def _word_lens_to_bmes(word_lens): """ @@ -13,11 +16,11 @@ def _word_lens_to_bmes(word_lens): """ tags = [] for word_len in word_lens: - if word_len==1: + if word_len == 1: tags.append('S') else: tags.append('B') - tags.extend(['M']*(word_len-2)) + tags.extend(['M'] * (word_len - 2)) tags.append('E') return tags @@ -30,10 +33,10 @@ def _word_lens_to_segapp(word_lens): """ tags = [] for word_len in word_lens: - if word_len==1: + if word_len == 1: tags.append('SEG') else: - tags.extend(['APP']*(word_len-1)) + tags.extend(['APP'] * (word_len - 1)) tags.append('SEG') return tags @@ -97,13 +100,21 @@ def _digit_span_to_special_tag(span): else: return '' + def _find_and_replace_digit_spans(line): - # only consider words start with number, contains '.', characters. - # If ends with space, will be processed - # If ends with Chinese character, will be processed - # If ends with or contains english char, not handled. - # floats are replaced by - # otherwise unkdgt + """ + only consider words start with number, contains '.', characters. + + If ends with space, will be processed + + If ends with Chinese character, will be processed + + If ends with or contains english char, not handled. + + floats are replaced by + + otherwise unkdgt + """ new_line = '' pattern = '\d[\d\\.﹒·]*(?=[\u4e00-\u9fff ,%,。!<-“])' prev_end = 0 @@ -136,17 +147,18 @@ class CWSPipe(Pipe): :param bool bigrams: 是否增加一列bigram. bigram的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...] :param bool trigrams: 是否增加一列trigram. trigram的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] """ + def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, bigrams=False, trigrams=False): - if encoding_type=='bmes': + if encoding_type == 'bmes': self.word_lens_to_tags = _word_lens_to_bmes else: self.word_lens_to_tags = _word_lens_to_segapp - + self.dataset_name = dataset_name self.bigrams = bigrams self.trigrams = trigrams self.replace_num_alpha = replace_num_alpha - + def _tokenize(self, data_bundle): """ 将data_bundle中的'chars'列切分成一个一个的word. @@ -162,10 +174,10 @@ class CWSPipe(Pipe): char = [] subchar = [] for c in word: - if c=='<': + if c == '<': subchar.append(c) continue - if c=='>' and subchar[0]=='<': + if c == '>' and subchar[0] == '<': char.append(''.join(subchar)) subchar = [] if subchar: @@ -175,12 +187,12 @@ class CWSPipe(Pipe): char.extend(subchar) chars.append(char) return chars - + for name, dataset in data_bundle.datasets.items(): dataset.apply_field(split_word_into_chars, field_name=Const.CHAR_INPUT, new_field_name=Const.CHAR_INPUT) return data_bundle - + def process(self, data_bundle: DataBundle) -> DataBundle: """ 可以处理的DataSet需要包含raw_words列 @@ -196,42 +208,43 @@ class CWSPipe(Pipe): :return: """ data_bundle.copy_field(Const.RAW_WORD, Const.CHAR_INPUT) - + if self.replace_num_alpha: data_bundle.apply_field(_find_and_replace_alpha_spans, Const.CHAR_INPUT, Const.CHAR_INPUT) data_bundle.apply_field(_find_and_replace_digit_spans, Const.CHAR_INPUT, Const.CHAR_INPUT) - + self._tokenize(data_bundle) - + for name, dataset in data_bundle.datasets.items(): - dataset.apply_field(lambda chars:self.word_lens_to_tags(map(len, chars)), field_name=Const.CHAR_INPUT, + dataset.apply_field(lambda chars: self.word_lens_to_tags(map(len, chars)), field_name=Const.CHAR_INPUT, new_field_name=Const.TARGET) - dataset.apply_field(lambda chars:list(chain(*chars)), field_name=Const.CHAR_INPUT, + dataset.apply_field(lambda chars: list(chain(*chars)), field_name=Const.CHAR_INPUT, new_field_name=Const.CHAR_INPUT) input_field_names = [Const.CHAR_INPUT] if self.bigrams: for name, dataset in data_bundle.datasets.items(): - dataset.apply_field(lambda chars: [c1+c2 for c1, c2 in zip(chars, chars[1:]+[''])], + dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])], field_name=Const.CHAR_INPUT, new_field_name='bigrams') input_field_names.append('bigrams') if self.trigrams: for name, dataset in data_bundle.datasets.items(): - dataset.apply_field(lambda chars: [c1+c2+c3 for c1, c2, c3 in zip(chars, chars[1:]+[''], chars[2:]+['']*2)], + dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in + zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], field_name=Const.CHAR_INPUT, new_field_name='trigrams') input_field_names.append('trigrams') - + _indexize(data_bundle, input_field_names, Const.TARGET) - + input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names target_fields = [Const.TARGET, Const.INPUT_LEN] for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(Const.CHAR_INPUT) - + data_bundle.set_input(*input_fields) data_bundle.set_target(*target_fields) - + return data_bundle - + def process_from_file(self, paths=None) -> DataBundle: """ @@ -239,8 +252,9 @@ class CWSPipe(Pipe): :return: """ if self.dataset_name is None and paths is None: - raise RuntimeError("You have to set `paths` when calling process_from_file() or `dataset_name `when initialization.") + raise RuntimeError( + "You have to set `paths` when calling process_from_file() or `dataset_name `when initialization.") if self.dataset_name is not None and paths is not None: raise RuntimeError("You cannot specify `paths` and `dataset_name` simultaneously") data_bundle = CWSLoader(self.dataset_name).load(paths) - return self.process(data_bundle) \ No newline at end of file + return self.process(data_bundle)