|
|
@@ -1,10 +1,13 @@ |
|
|
|
import re
from itertools import chain

from .pipe import Pipe
from .utils import _indexize
from .. import DataBundle
from ..loader import CWSLoader
from ...core.const import Const
|
|
|
|
|
|
|
|
|
|
|
def _word_lens_to_bmes(word_lens): |
|
|
|
""" |
|
|
|
|
|
|
@@ -13,11 +16,11 @@ def _word_lens_to_bmes(word_lens): |
|
|
|
""" |
|
|
|
tags = [] |
|
|
|
for word_len in word_lens: |
|
|
|
if word_len==1: |
|
|
|
if word_len == 1: |
|
|
|
tags.append('S') |
|
|
|
else: |
|
|
|
tags.append('B') |
|
|
|
tags.extend(['M']*(word_len-2)) |
|
|
|
tags.extend(['M'] * (word_len - 2)) |
|
|
|
tags.append('E') |
|
|
|
return tags |
|
|
|
|
|
|
@@ -30,10 +33,10 @@ def _word_lens_to_segapp(word_lens): |
|
|
|
""" |
|
|
|
tags = [] |
|
|
|
for word_len in word_lens: |
|
|
|
if word_len==1: |
|
|
|
if word_len == 1: |
|
|
|
tags.append('SEG') |
|
|
|
else: |
|
|
|
tags.extend(['APP']*(word_len-1)) |
|
|
|
tags.extend(['APP'] * (word_len - 1)) |
|
|
|
tags.append('SEG') |
|
|
|
return tags |
|
|
|
|
|
|
@@ -97,13 +100,21 @@ def _digit_span_to_special_tag(span): |
|
|
|
else: |
|
|
|
return '<NUM>' |
|
|
|
|
|
|
|
|
|
|
|
def _find_and_replace_digit_spans(line):
    """
    Replace digit spans in ``line`` with special tags.

    only consider words start with number, contains '.', characters.

    If ends with space, will be processed

    If ends with Chinese character, will be processed

    If ends with or contains english char, not handled.

    floats are replaced by <DEC>

    otherwise unkdgt
    """
    # Accumulates the rewritten line as spans are substituted.
    new_line = ''
    # Digit run (possibly with decimal separators '.', '﹒', '·') that must be
    # followed by a CJK char, space, or one of the listed punctuation marks.
    # NOTE(review): not a raw string — '\d' relies on Python passing unknown
    # escapes through; consider r'...' (behavior unchanged today).
    pattern = '\d[\d\\.﹒·]*(?=[\u4e00-\u9fff ,%,。!<-“])'
    # End offset of the previously matched span in the original line.
    prev_end = 0
|
|
@@ -136,17 +147,18 @@ class CWSPipe(Pipe): |
|
|
|
:param bool bigrams: 是否增加一列bigram. bigram的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...] |
|
|
|
:param bool trigrams: 是否增加一列trigram. trigram的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, bigrams=False, trigrams=False):
    """
    :param dataset_name: name of the builtin dataset to load in
        process_from_file(); mutually exclusive with its ``paths`` argument.
    :param encoding_type: 'bmes' selects BMES tagging; any other value
        selects SEG/APP tagging.
    :param replace_num_alpha: whether to replace digit/alpha spans with
        special tags (e.g. '<NUM>') before tokenizing.
    :param bigrams: whether to add a 'bigrams' field.
    :param trigrams: whether to add a 'trigrams' field.
    """
    # Pick the word-length -> tag-sequence encoder once, up front.
    if encoding_type == 'bmes':
        self.word_lens_to_tags = _word_lens_to_bmes
    else:
        self.word_lens_to_tags = _word_lens_to_segapp

    self.dataset_name = dataset_name
    self.bigrams = bigrams
    self.trigrams = trigrams
    self.replace_num_alpha = replace_num_alpha
|
|
|
|
|
|
|
|
|
|
|
def _tokenize(self, data_bundle):
    """
    Split the CHAR_INPUT field of every dataset in ``data_bundle`` into
    per-character lists, keeping the content of a '<...>' span grouped.

    :param data_bundle: bundle whose Const.CHAR_INPUT field holds
        space-separated words (one string per instance).
    :return: the same data_bundle, modified in place.
    """

    def split_word_into_chars(raw_chars):
        # One list of "chars" per word; a '<...' span is buffered in
        # ``subchar`` until its closing '>' is seen.
        words = raw_chars.split()
        chars = []
        for word in words:
            char = []
            subchar = []
            for c in word:
                if c == '<':
                    subchar.append(c)
                    continue
                # Guard added: the old code indexed subchar[0] unconditionally,
                # so a bare '>' with no pending '<' raised IndexError.
                if c == '>' and subchar and subchar[0] == '<':
                    # NOTE(review): the closing '>' is NOT included in the
                    # joined span; it falls through and is appended as its own
                    # char below — confirm this is intended for '<NUM>' etc.
                    char.append(''.join(subchar))
                    subchar = []
                if subchar:
                    subchar.append(c)
                else:
                    char.append(c)
            # Flush an unterminated '<...' prefix as individual chars.
            char.extend(subchar)
            chars.append(char)
        return chars

    for name, dataset in data_bundle.datasets.items():
        dataset.apply_field(split_word_into_chars, field_name=Const.CHAR_INPUT,
                            new_field_name=Const.CHAR_INPUT)
    return data_bundle
|
|
|
|
|
|
|
|
|
|
|
def process(self, data_bundle: DataBundle) -> DataBundle:
    """
    Process a DataBundle whose datasets contain a raw_words field.

    Pipeline: copy RAW_WORD to CHAR_INPUT; optionally replace alpha/digit
    spans with special tags; split into characters; derive per-character
    TARGET tags from word lengths; optionally add bigram/trigram fields;
    index all input fields and the target; add sequence lengths; mark
    input/target fields.

    :param data_bundle: bundle with Const.RAW_WORD present in each dataset.
    :return: the processed data_bundle.
    """
    data_bundle.copy_field(Const.RAW_WORD, Const.CHAR_INPUT)

    if self.replace_num_alpha:
        data_bundle.apply_field(_find_and_replace_alpha_spans, Const.CHAR_INPUT, Const.CHAR_INPUT)
        data_bundle.apply_field(_find_and_replace_digit_spans, Const.CHAR_INPUT, Const.CHAR_INPUT)

    self._tokenize(data_bundle)

    for name, dataset in data_bundle.datasets.items():
        # Tags are derived from each word's length, then chars are flattened.
        dataset.apply_field(lambda chars: self.word_lens_to_tags(map(len, chars)), field_name=Const.CHAR_INPUT,
                            new_field_name=Const.TARGET)
        dataset.apply_field(lambda chars: list(chain(*chars)), field_name=Const.CHAR_INPUT,
                            new_field_name=Const.CHAR_INPUT)
    input_field_names = [Const.CHAR_INPUT]
    if self.bigrams:
        for name, dataset in data_bundle.datasets.items():
            dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
                                field_name=Const.CHAR_INPUT, new_field_name='bigrams')
        input_field_names.append('bigrams')
    if self.trigrams:
        for name, dataset in data_bundle.datasets.items():
            dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
                                               zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
                                field_name=Const.CHAR_INPUT, new_field_name='trigrams')
        input_field_names.append('trigrams')

    _indexize(data_bundle, input_field_names, Const.TARGET)

    input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
    target_fields = [Const.TARGET, Const.INPUT_LEN]
    for name, dataset in data_bundle.datasets.items():
        dataset.add_seq_len(Const.CHAR_INPUT)

    data_bundle.set_input(*input_fields)
    data_bundle.set_target(*target_fields)

    return data_bundle
|
|
|
|
|
|
|
|
|
|
|
def process_from_file(self, paths=None) -> DataBundle:
    """
    Load a CWS dataset and run :meth:`process` on it.

    Exactly one of ``paths`` and ``self.dataset_name`` must be set.

    :param paths: path(s) to the dataset files; mutually exclusive with
        the ``dataset_name`` given at initialization.
    :return: the processed DataBundle.
    :raises RuntimeError: if neither or both of ``paths`` and
        ``dataset_name`` are specified.
    """
    if self.dataset_name is None and paths is None:
        raise RuntimeError(
            "You have to set `paths` when calling process_from_file() or `dataset_name` when initialization.")
    if self.dataset_name is not None and paths is not None:
        raise RuntimeError("You cannot specify `paths` and `dataset_name` simultaneously")
    data_bundle = CWSLoader(self.dataset_name).load(paths)
    return self.process(data_bundle)