From 4c457e99248acf5e0d9384013a480d0632bf9877 Mon Sep 17 00:00:00 2001 From: ChenXin Date: Mon, 26 Oct 2020 13:22:47 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=B0=86=20nltk=20=E4=BB=8E=E4=BE=9D?= =?UTF-8?q?=E8=B5=96=E4=B8=AD=E5=88=A0=E9=99=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/io/pipe/classification.py | 124 +++++++++++++++--------------- requirements.txt | 1 - 2 files changed, 64 insertions(+), 61 deletions(-) diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py index c59ffe5d..9475a092 100644 --- a/fastNLP/io/pipe/classification.py +++ b/fastNLP/io/pipe/classification.py @@ -17,7 +17,11 @@ __all__ = [ import re import warnings -from nltk import Tree +try: + from nltk import Tree +except: + # only nltk in some versions can run + pass from .pipe import Pipe from .utils import get_tokenizer, _indexize, _add_words_field, _add_chars_field, _granularize @@ -32,12 +36,12 @@ from ...core.instance import Instance class CLSBasePipe(Pipe): - - def __init__(self, lower: bool=False, tokenizer: str='spacy', lang='en'): + + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', lang='en'): super().__init__() self.lower = lower self.tokenizer = get_tokenizer(tokenizer, lang=lang) - + def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): r""" 将DataBundle中的数据进行tokenize @@ -50,9 +54,9 @@ class CLSBasePipe(Pipe): new_field_name = new_field_name or field_name for name, dataset in data_bundle.datasets.items(): dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) - + return data_bundle - + def process(self, data_bundle: DataBundle): r""" 传入的DataSet应该具备如下的结构 @@ -73,15 +77,15 @@ class CLSBasePipe(Pipe): data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) # 建立词表并index data_bundle = _indexize(data_bundle=data_bundle) - + for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(Const.INPUT) - + data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) data_bundle.set_target(Const.TARGET) - + return data_bundle - + def process_from_file(self, paths) -> DataBundle: r""" 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` @@ -151,7 +155,7 @@ class YelpFullPipe(CLSBasePipe): """ if self.tag_map is not None: data_bundle = _granularize(data_bundle, self.tag_map) - + data_bundle = super().process(data_bundle) return data_bundle @@ -231,7 +235,7 @@ class AGsNewsPipe(CLSBasePipe): +-------------+-----------+--------+-------+---------+ """ - + def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): r""" @@ -239,7 +243,7 @@ class AGsNewsPipe(CLSBasePipe): :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 """ super().__init__(lower=lower, tokenizer=tokenizer, lang='en') - + def process_from_file(self, paths=None): r""" :param str paths: @@ -272,7 +276,7 @@ class DBPediaPipe(CLSBasePipe): +-------------+-----------+--------+-------+---------+ """ - + def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): r""" @@ -280,7 +284,7 @@ class DBPediaPipe(CLSBasePipe): :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 """ super().__init__(lower=lower, tokenizer=tokenizer, lang='en') - + def process_from_file(self, paths=None): r""" :param str paths: @@ -369,7 +373,7 @@ class SSTPipe(CLSBasePipe): instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label()) ds.append(instance) data_bundle.set_dataset(ds, name) - + # 
根据granularity设置tag data_bundle = _granularize(data_bundle, tag_map=self.tag_map) @@ -525,6 +529,7 @@ class ChnSentiCorpPipe(Pipe): +-------------+-----------+--------+-------+---------+ """ + def __init__(self, bigrams=False, trigrams=False): r""" @@ -536,10 +541,10 @@ class ChnSentiCorpPipe(Pipe): data_bundle.get_vocab('trigrams')获取. """ super().__init__() - + self.bigrams = bigrams self.trigrams = trigrams - + def _tokenize(self, data_bundle): r""" 将DataSet中的"复旦大学"拆分为["复", "旦", "大", "学"]. 未来可以通过扩展这个函数实现分词。 @@ -549,8 +554,8 @@ class ChnSentiCorpPipe(Pipe): """ data_bundle.apply_field(list, field_name=Const.CHAR_INPUT, new_field_name=Const.CHAR_INPUT) return data_bundle - - def process(self, data_bundle:DataBundle): + + def process(self, data_bundle: DataBundle): r""" 可以处理的DataSet应该具备以下的field @@ -565,9 +570,9 @@ class ChnSentiCorpPipe(Pipe): :return: """ _add_chars_field(data_bundle, lower=False) - + data_bundle = self._tokenize(data_bundle) - + input_field_names = [Const.CHAR_INPUT] if self.bigrams: for name, dataset in data_bundle.iter_datasets(): @@ -580,21 +585,21 @@ class ChnSentiCorpPipe(Pipe): zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], field_name=Const.CHAR_INPUT, new_field_name='trigrams') input_field_names.append('trigrams') - + # index _indexize(data_bundle, input_field_names, Const.TARGET) - + input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names target_fields = [Const.TARGET] - + for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(Const.CHAR_INPUT) - + data_bundle.set_input(*input_fields) data_bundle.set_target(*target_fields) - + return data_bundle - + def process_from_file(self, paths=None): r""" @@ -604,7 +609,7 @@ class ChnSentiCorpPipe(Pipe): # 读取数据 data_bundle = ChnSentiCorpLoader().load(paths) data_bundle = self.process(data_bundle) - + return data_bundle @@ -637,26 +642,26 @@ class THUCNewsPipe(CLSBasePipe): 。如果设置为True,返回的DataSet将有一列名为trigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 data_bundle.get_vocab('trigrams')获取. 
""" - + def __init__(self, bigrams=False, trigrams=False): super().__init__() - + self.bigrams = bigrams self.trigrams = trigrams - + def _chracter_split(self, sent): return list(sent) # return [w for w in sent] - + def _raw_split(self, sent): return sent.split() - + def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): new_field_name = new_field_name or field_name for name, dataset in data_bundle.datasets.items(): dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name) return data_bundle - + def process(self, data_bundle: DataBundle): r""" 可处理的DataSet应具备如下的field @@ -673,14 +678,14 @@ class THUCNewsPipe(CLSBasePipe): # 根据granularity设置tag tag_map = {'体育': 0, '财经': 1, '房产': 2, '家居': 3, '教育': 4, '科技': 5, '时尚': 6, '时政': 7, '游戏': 8, '娱乐': 9} data_bundle = _granularize(data_bundle=data_bundle, tag_map=tag_map) - + # clean,lower - + # CWS(tokenize) data_bundle = self._tokenize(data_bundle=data_bundle, field_name='raw_chars', new_field_name='chars') - + input_field_names = [Const.CHAR_INPUT] - + # n-grams if self.bigrams: for name, dataset in data_bundle.iter_datasets(): @@ -693,22 +698,22 @@ class THUCNewsPipe(CLSBasePipe): zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], field_name=Const.CHAR_INPUT, new_field_name='trigrams') input_field_names.append('trigrams') - + # index data_bundle = _indexize(data_bundle=data_bundle, input_field_names=Const.CHAR_INPUT) - + # add length for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(field_name=Const.CHAR_INPUT, new_field_name=Const.INPUT_LEN) - + input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names target_fields = [Const.TARGET] - + data_bundle.set_input(*input_fields) data_bundle.set_target(*target_fields) - + return data_bundle - + def process_from_file(self, paths=None): r""" :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 @@ -749,22 +754,22 @@ class WeiboSenti100kPipe(CLSBasePipe): 。如果设置为True,返回的DataSet将有一列名为trigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 data_bundle.get_vocab('trigrams')获取. 
""" - + def __init__(self, bigrams=False, trigrams=False): super().__init__() - + self.bigrams = bigrams self.trigrams = trigrams - + def _chracter_split(self, sent): return list(sent) - + def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): new_field_name = new_field_name or field_name for name, dataset in data_bundle.datasets.items(): dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name) return data_bundle - + def process(self, data_bundle: DataBundle): r""" 可处理的DataSet应具备以下的field @@ -779,12 +784,12 @@ class WeiboSenti100kPipe(CLSBasePipe): :return: """ # clean,lower - + # CWS(tokenize) data_bundle = self._tokenize(data_bundle=data_bundle, field_name='raw_chars', new_field_name='chars') - + input_field_names = [Const.CHAR_INPUT] - + # n-grams if self.bigrams: for name, dataset in data_bundle.iter_datasets(): @@ -797,22 +802,22 @@ class WeiboSenti100kPipe(CLSBasePipe): zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], field_name=Const.CHAR_INPUT, new_field_name='trigrams') input_field_names.append('trigrams') - + # index data_bundle = _indexize(data_bundle=data_bundle, input_field_names='chars') - + # add length for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(field_name=Const.CHAR_INPUT, new_field_name=Const.INPUT_LEN) - + input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names target_fields = [Const.TARGET] - + data_bundle.set_input(*input_fields) data_bundle.set_target(*target_fields) - + return data_bundle - + def process_from_file(self, paths=None): r""" :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 @@ -822,4 +827,3 @@ class WeiboSenti100kPipe(CLSBasePipe): data_bundle = data_loader.load(paths) data_bundle = self.process(data_bundle) return data_bundle - diff --git a/requirements.txt b/requirements.txt index 242301be..81fb307c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ numpy>=1.14.2 torch>=1.0.0 tqdm>=4.28.1 -nltk>=3.4.1 prettytable>=0.7.2 requests spacy From ca25baf6b9007f6347fe7969e723537115fc64f9 Mon Sep 17 00:00:00 2001 From: ChenXin Date: Mon, 26 Oct 2020 13:53:46 +0800 Subject: [PATCH 2/2] =?UTF-8?q?1.=20SKIP=20test=5Fprocess=5Ffrom=5Ffile=20?= =?UTF-8?q?2.=20doc=5Futils=20=E5=A2=9E=E5=8A=A0=E4=BA=86=20=5F=5Fall=5F?= =?UTF-8?q?=5F=20=E7=9A=84=E6=A3=80=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/doc_utils.py | 10 +++++++--- test/io/pipe/test_classification.py | 8 ++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/fastNLP/doc_utils.py b/fastNLP/doc_utils.py index 3f7889e4..119db776 100644 --- a/fastNLP/doc_utils.py +++ b/fastNLP/doc_utils.py @@ -23,7 +23,9 @@ def doc_process(m): while 1: defined_m = sys.modules[module_name] try: - if "undocumented" not in defined_m.__doc__ and name in defined_m.__all__: + if not hasattr(defined_m, "__all__"): + print("Warning: Module {} lacks `__all__`".format(module_name)) + elif "undocumented" not in defined_m.__doc__ and name in defined_m.__all__: obj.__doc__ = r"别名 :class:`" + m.__name__ + "." + name + "`" \ + " :class:`" + module_name + "." 
+ name + "`\n" + obj.__doc__ break @@ -34,7 +36,7 @@ def doc_process(m): except: print("Warning: Module {} lacks `__doc__`".format(module_name)) break - + # 识别并标注基类,只有基类也在 fastNLP 中定义才显示 if inspect.isclass(obj): @@ -45,7 +47,9 @@ def doc_process(m): for i in range(len(parts) - 1): defined_m = sys.modules[module_name] try: - if "undocumented" not in defined_m.__doc__ and name in defined_m.__all__: + if not hasattr(defined_m, "__all__"): + print("Warning: Module {} lacks `__all__`".format(module_name)) + elif "undocumented" not in defined_m.__doc__ and name in defined_m.__all__: obj.__doc__ = r"基类 :class:`" + defined_m.__name__ + "." + base.__name__ + "` \n\n" + obj.__doc__ break module_name += "." + parts[i + 1] diff --git a/test/io/pipe/test_classification.py b/test/io/pipe/test_classification.py index c6bd5444..8ebdb2df 100644 --- a/test/io/pipe/test_classification.py +++ b/test/io/pipe/test_classification.py @@ -10,7 +10,7 @@ from fastNLP.io.pipe.classification import ChnSentiCorpPipe, THUCNewsPipe, Weibo @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") class TestClassificationPipe(unittest.TestCase): def test_process_from_file(self): - for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]: + for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]: with self.subTest(pipe=pipe): print(pipe) data_bundle = pipe(tokenizer='raw').process_from_file() @@ -33,6 +33,7 @@ class TestCNClassificationPipe(unittest.TestCase): print(data_bundle) +@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") class TestRunClassificationPipe(unittest.TestCase): def test_process_from_file(self): data_set_dict = { @@ -79,15 +80,14 @@ class TestRunClassificationPipe(unittest.TestCase): data_bundle = pipe(tokenizer='raw').process_from_file(path) else: data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(path) - + self.assertTrue(isinstance(data_bundle, DataBundle)) self.assertEqual(len(data_set), data_bundle.num_dataset) for name, dataset in data_bundle.iter_datasets(): self.assertTrue(name in data_set.keys()) self.assertEqual(data_set[name], len(dataset)) - + self.assertEqual(len(vocab), data_bundle.num_vocab) for name, vocabs in data_bundle.iter_vocabs(): self.assertTrue(name in vocab.keys()) self.assertEqual(vocab[name], len(vocabs)) -
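
Note on the first commit: it turns nltk into an optional dependency by guarding `from nltk import Tree` with a try/except. A minimal sketch of the same pattern with an explicit, deferred failure at the point of use — the `_NLTK_AVAILABLE` flag, the helper name and the error message are illustrative additions, not part of the patch:

    # Optional-import pattern: nltk is only needed by the SST pipes, which call
    # Tree.fromstring on bracketed constituency-parse strings.
    try:
        from nltk import Tree
        _NLTK_AVAILABLE = True
    except ImportError:
        Tree = None
        _NLTK_AVAILABLE = False

    def parse_sst_line(line):
        """Parse one bracketed SST example, requiring nltk only when actually called."""
        if not _NLTK_AVAILABLE:
            raise RuntimeError("nltk is required to parse SST trees; install it with `pip install nltk`.")
        return Tree.fromstring(line)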
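
Several hunks above only re-indent the character n-gram construction shared by ChnSentiCorpPipe, THUCNewsPipe and WeiboSenti100kPipe. A worked example of that construction; the empty-string padding mirrors what the diff context shows, though the library may use a dedicated end marker such as '<eos>' instead:

    # Character n-grams as built in the bigrams/trigrams branches of the pipes above.
    pad = ''  # padding value as it appears in the diff context; possibly '<eos>' upstream
    chars = list("复旦大学")                       # ['复', '旦', '大', '学']
    bigrams = [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [pad])]
    trigrams = [c1 + c2 + c3
                for c1, c2, c3 in zip(chars, chars[1:] + [pad], chars[2:] + [pad] * 2)]
    print(bigrams)    # ['复旦', '旦大', '大学', '学']
    print(trigrams)   # ['复旦大', '旦大学', '大学', '学']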
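
The Travis-skipped tests exercise the pipes end to end; the same flow outside the test suite looks roughly like this sketch (calling process_from_file() with no path lets the loader fetch the default dataset, as TestCNClassificationPipe does):

    from fastNLP.io.pipe.classification import ChnSentiCorpPipe

    # Standalone equivalent of TestCNClassificationPipe.test_process_from_file.
    data_bundle = ChnSentiCorpPipe(bigrams=True, trigrams=True).process_from_file()
    print(data_bundle)                        # summary of the datasets and their vocabularies
    print(data_bundle.get_vocab('bigrams'))   # vocabulary built for the new 'bigrams' field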