From e0c86346619606f926e22c13140d56b36182817c Mon Sep 17 00:00:00 2001
From: Yige Xu
Date: Mon, 16 Sep 2019 16:16:23 +0800
Subject: [PATCH] add chinese char-level tokenizer

---
 fastNLP/io/pipe/matching.py |  4 ++--
 fastNLP/io/pipe/utils.py    | 21 +++++++++++++++------
 2 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py
index 7620a556..d6506f66 100644
--- a/fastNLP/io/pipe/matching.py
+++ b/fastNLP/io/pipe/matching.py
@@ -51,7 +51,7 @@ class MatchingBertPipe(Pipe):
         super().__init__()
 
         self.lower = bool(lower)
-        self.tokenizer = get_tokenizer(tokenizer=tokenizer)
+        self.tokenizer = get_tokenizer(tokenize_method=tokenizer)
 
     def _tokenize(self, data_bundle, field_names, new_field_names):
         """
@@ -191,7 +191,7 @@ class MatchingPipe(Pipe):
         super().__init__()
 
         self.lower = bool(lower)
-        self.tokenizer = get_tokenizer(tokenizer=tokenizer)
+        self.tokenizer = get_tokenizer(tokenize_method=tokenizer)
 
     def _tokenize(self, data_bundle, field_names, new_field_names):
         """
diff --git a/fastNLP/io/pipe/utils.py b/fastNLP/io/pipe/utils.py
index 92d61bfd..4925853f 100644
--- a/fastNLP/io/pipe/utils.py
+++ b/fastNLP/io/pipe/utils.py
@@ -65,27 +65,36 @@ def iob2bioes(tags: List[str]) -> List[str]:
     return new_tags
 
 
-def get_tokenizer(tokenizer: str, lang='en'):
+def get_tokenizer(tokenize_method: str, lang='en'):
     """
 
-    :param str tokenizer: the tokenization method to use
+    :param str tokenize_method: the tokenization method to use
     :param str lang: language; currently only en is supported
     :return: a tokenize function
     """
-    if tokenizer == 'spacy':
+    tokenizer_dict = {
+        'spacy': None,
+        'raw': _raw_split,
+        'cn-char': _cn_char_split,
+    }
+    if tokenize_method == 'spacy':
         import spacy
         spacy.prefer_gpu()
         if lang != 'en':
             raise RuntimeError("Spacy only supports en right now.")
         en = spacy.load(lang)
         tokenizer = lambda x: [w.text for w in en.tokenizer(x)]
-    elif tokenizer == 'raw':
-        tokenizer = _raw_split
+    elif tokenize_method in tokenizer_dict:
+        tokenizer = tokenizer_dict[tokenize_method]
     else:
-        raise RuntimeError("Only support `spacy`, `raw` tokenizer.")
+        raise RuntimeError(f"Only supports {list(tokenizer_dict.keys())} tokenizers.")
 
     return tokenizer
 
 
+def _cn_char_split(sent):
+    return [char for char in sent]
+
+
 def _raw_split(sent):
     return sent.split()
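---
Usage note (not part of the patch): a minimal sketch of the new `cn-char`
option, assuming the patch is applied and `get_tokenizer` is imported from
fastNLP/io/pipe/utils.py as defined above; the sample sentences are
illustrative only.

    from fastNLP.io.pipe.utils import get_tokenizer

    # The new Chinese char-level tokenizer splits a sentence into
    # one token per character.
    cn_tokenize = get_tokenizer(tokenize_method='cn-char')
    print(cn_tokenize('我爱自然语言处理'))
    # ['我', '爱', '自', '然', '语', '言', '处', '理']

    # The pre-existing whitespace tokenizer behaves as before.
    raw_tokenize = get_tokenizer(tokenize_method='raw')
    print(raw_tokenize('natural language processing'))
    # ['natural', 'language', 'processing']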