Browse Source

fix some importing bugs

tags/v0.4.10
ChenXin 5 years ago
parent
commit
9e16791c53
1 changed files with 49 additions and 35 deletions
  1. +49
    -35
      fastNLP/io/pipe/cws.py

+ 49
- 35
fastNLP/io/pipe/cws.py View File

@@ -1,10 +1,13 @@
import re
from itertools import chain

from .pipe import Pipe
from .utils import _indexize
from .. import DataBundle
from ..loader import CWSLoader
from ... import Const
from itertools import chain
from .utils import _indexize
import re
from ...core.const import Const


def _word_lens_to_bmes(word_lens):
"""

@@ -13,11 +16,11 @@ def _word_lens_to_bmes(word_lens):
"""
tags = []
for word_len in word_lens:
if word_len==1:
if word_len == 1:
tags.append('S')
else:
tags.append('B')
tags.extend(['M']*(word_len-2))
tags.extend(['M'] * (word_len - 2))
tags.append('E')
return tags

@@ -30,10 +33,10 @@ def _word_lens_to_segapp(word_lens):
"""
tags = []
for word_len in word_lens:
if word_len==1:
if word_len == 1:
tags.append('SEG')
else:
tags.extend(['APP']*(word_len-1))
tags.extend(['APP'] * (word_len - 1))
tags.append('SEG')
return tags

@@ -97,13 +100,21 @@ def _digit_span_to_special_tag(span):
else:
return '<NUM>'


def _find_and_replace_digit_spans(line):
# only consider words start with number, contains '.', characters.
# If ends with space, will be processed
# If ends with Chinese character, will be processed
# If ends with or contains english char, not handled.
# floats are replaced by <DEC>
# otherwise unkdgt
"""
only consider words start with number, contains '.', characters.
If ends with space, will be processed
If ends with Chinese character, will be processed
If ends with or contains english char, not handled.
floats are replaced by <DEC>
otherwise unkdgt
"""
new_line = ''
pattern = '\d[\d\\.﹒·]*(?=[\u4e00-\u9fff ,%,。!<-“])'
prev_end = 0
@@ -136,17 +147,18 @@ class CWSPipe(Pipe):
:param bool bigrams: 是否增加一列bigram. bigram的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]
:param bool trigrams: 是否增加一列trigram. trigram的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...]
"""
def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, bigrams=False, trigrams=False):
if encoding_type=='bmes':
if encoding_type == 'bmes':
self.word_lens_to_tags = _word_lens_to_bmes
else:
self.word_lens_to_tags = _word_lens_to_segapp
self.dataset_name = dataset_name
self.bigrams = bigrams
self.trigrams = trigrams
self.replace_num_alpha = replace_num_alpha
def _tokenize(self, data_bundle):
"""
将data_bundle中的'chars'列切分成一个一个的word.
@@ -162,10 +174,10 @@ class CWSPipe(Pipe):
char = []
subchar = []
for c in word:
if c=='<':
if c == '<':
subchar.append(c)
continue
if c=='>' and subchar[0]=='<':
if c == '>' and subchar[0] == '<':
char.append(''.join(subchar))
subchar = []
if subchar:
@@ -175,12 +187,12 @@ class CWSPipe(Pipe):
char.extend(subchar)
chars.append(char)
return chars
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(split_word_into_chars, field_name=Const.CHAR_INPUT,
new_field_name=Const.CHAR_INPUT)
return data_bundle
def process(self, data_bundle: DataBundle) -> DataBundle:
"""
可以处理的DataSet需要包含raw_words列
@@ -196,42 +208,43 @@ class CWSPipe(Pipe):
:return:
"""
data_bundle.copy_field(Const.RAW_WORD, Const.CHAR_INPUT)
if self.replace_num_alpha:
data_bundle.apply_field(_find_and_replace_alpha_spans, Const.CHAR_INPUT, Const.CHAR_INPUT)
data_bundle.apply_field(_find_and_replace_digit_spans, Const.CHAR_INPUT, Const.CHAR_INPUT)
self._tokenize(data_bundle)
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(lambda chars:self.word_lens_to_tags(map(len, chars)), field_name=Const.CHAR_INPUT,
dataset.apply_field(lambda chars: self.word_lens_to_tags(map(len, chars)), field_name=Const.CHAR_INPUT,
new_field_name=Const.TARGET)
dataset.apply_field(lambda chars:list(chain(*chars)), field_name=Const.CHAR_INPUT,
dataset.apply_field(lambda chars: list(chain(*chars)), field_name=Const.CHAR_INPUT,
new_field_name=Const.CHAR_INPUT)
input_field_names = [Const.CHAR_INPUT]
if self.bigrams:
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(lambda chars: [c1+c2 for c1, c2 in zip(chars, chars[1:]+['<eos>'])],
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
field_name=Const.CHAR_INPUT, new_field_name='bigrams')
input_field_names.append('bigrams')
if self.trigrams:
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(lambda chars: [c1+c2+c3 for c1, c2, c3 in zip(chars, chars[1:]+['<eos>'], chars[2:]+['<eos>']*2)],
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
field_name=Const.CHAR_INPUT, new_field_name='trigrams')
input_field_names.append('trigrams')
_indexize(data_bundle, input_field_names, Const.TARGET)
input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
target_fields = [Const.TARGET, Const.INPUT_LEN]
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.CHAR_INPUT)
data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields)
return data_bundle
def process_from_file(self, paths=None) -> DataBundle:
"""

@@ -239,8 +252,9 @@ class CWSPipe(Pipe):
:return:
"""
if self.dataset_name is None and paths is None:
raise RuntimeError("You have to set `paths` when calling process_from_file() or `dataset_name `when initialization.")
raise RuntimeError(
"You have to set `paths` when calling process_from_file() or `dataset_name `when initialization.")
if self.dataset_name is not None and paths is not None:
raise RuntimeError("You cannot specify `paths` and `dataset_name` simultaneously")
data_bundle = CWSLoader(self.dataset_name).load(paths)
return self.process(data_bundle)
return self.process(data_bundle)

Loading…
Cancel
Save