diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py index 98363fce..a29de173 100644 --- a/fastNLP/io/pipe/classification.py +++ b/fastNLP/io/pipe/classification.py @@ -37,10 +37,11 @@ from fastNLP.core.log import logger class CLSBasePipe(Pipe): - def __init__(self, lower: bool = False, tokenizer: str = 'raw', lang='en'): + def __init__(self, lower: bool = False, tokenizer: str = 'raw', lang='en', num_proc=0): super().__init__() self.lower = lower self.tokenizer = get_tokenizer(tokenizer, lang=lang) + self.num_proc = num_proc def _tokenize(self, data_bundle, field_name='words', new_field_name=None): r""" @@ -53,7 +54,8 @@ class CLSBasePipe(Pipe): """ new_field_name = new_field_name or field_name for name, dataset in data_bundle.iter_datasets(): - dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) + dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name, + num_proc=self.num_proc) return data_bundle @@ -117,7 +119,7 @@ class YelpFullPipe(CLSBasePipe): """ - def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy'): + def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy', num_proc=0): r""" :param bool lower: 是否对输入进行小写化。 @@ -125,7 +127,7 @@ class YelpFullPipe(CLSBasePipe): 1、2归为1类,3归为1类,4、5归为1类;若为5, 则有5分类问题。 :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 """ - super().__init__(lower=lower, tokenizer=tokenizer, lang='en') + super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc) assert granularity in (2, 3, 5), "granularity can only be 2,3,5." self.granularity = granularity @@ -191,13 +193,13 @@ class YelpPolarityPipe(CLSBasePipe): """ - def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0): r""" :param bool lower: 是否对输入进行小写化。 :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 """ - super().__init__(lower=lower, tokenizer=tokenizer, lang='en') + super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc) def process_from_file(self, paths=None): r""" @@ -233,13 +235,13 @@ class AGsNewsPipe(CLSBasePipe): """ - def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0): r""" :param bool lower: 是否对输入进行小写化。 :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 """ - super().__init__(lower=lower, tokenizer=tokenizer, lang='en') + super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc) def process_from_file(self, paths=None): r""" @@ -274,13 +276,13 @@ class DBPediaPipe(CLSBasePipe): """ - def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0): r""" :param bool lower: 是否对输入进行小写化。 :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 """ - super().__init__(lower=lower, tokenizer=tokenizer, lang='en') + super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc) def process_from_file(self, paths=None): r""" @@ -315,7 +317,7 @@ class SSTPipe(CLSBasePipe): """ - def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'): + def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy', num_proc=0): r""" :param bool subtree: 是否将train, test, dev数据展开为子树,扩充数据量。 Default: ``False`` @@ -325,7 +327,7 @@ class SSTPipe(CLSBasePipe): 0、1归为1类,2归为1类,3、4归为1类;若为5, 则有5分类问题。 :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 """ - super().__init__(tokenizer=tokenizer, lang='en') + super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) self.subtree = subtree self.train_tree = train_subtree self.lower = lower @@ -407,13 +409,13 @@ class SST2Pipe(CLSBasePipe): """ - def __init__(self, lower=False, tokenizer='raw'): + def __init__(self, lower=False, tokenizer='raw', num_proc=0): r""" :param bool lower: 是否对输入进行小写化。 :param str tokenizer: 使用哪种tokenize方式将数据切成单词。 """ - super().__init__(lower=lower, tokenizer=tokenizer, lang='en') + super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc) def process_from_file(self, paths=None): r""" @@ -452,13 +454,13 @@ class IMDBPipe(CLSBasePipe): """ - def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0): r""" :param bool lower: 是否将words列的数据小写。 :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 """ - super().__init__(tokenizer=tokenizer, lang='en') + super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) self.lower = lower def process(self, data_bundle: DataBundle): @@ -483,7 +485,7 @@ class IMDBPipe(CLSBasePipe): return raw_words for name, dataset in data_bundle.datasets.items(): - dataset.apply_field(replace_br, field_name='raw_words', new_field_name='raw_words') + dataset.apply_field(replace_br, field_name='raw_words', new_field_name='raw_words', num_proc=self.num_proc) data_bundle = super().process(data_bundle) @@ -527,7 +529,7 @@ class ChnSentiCorpPipe(Pipe): """ - def __init__(self, bigrams=False, trigrams=False): + def __init__(self, bigrams=False, trigrams=False, num_proc: int = 0): r""" :param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 @@ -541,10 +543,11 @@ class ChnSentiCorpPipe(Pipe): self.bigrams = bigrams self.trigrams = trigrams + self.num_proc = num_proc def _tokenize(self, data_bundle): r""" - 将DataSet中的"复旦大学"拆分为["复", "旦", "大", "学"]. 未来可以通过扩展这个函数实现分词。 + 将 DataSet 中的"复旦大学"拆分为 ["复", "旦", "大", "学"] . 未来可以通过扩展这个函数实现分词。 :param data_bundle: :return: @@ -571,24 +574,26 @@ class ChnSentiCorpPipe(Pipe): data_bundle = self._tokenize(data_bundle) input_field_names = ['chars'] + + def bigrams(chars): + return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])] + + def trigrams(chars): + return [c1 + c2 + c3 for c1, c2, c3 in + zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)] + if self.bigrams: for name, dataset in data_bundle.iter_datasets(): - dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])], - field_name='chars', new_field_name='bigrams') + dataset.apply_field(bigrams,field_name='chars', new_field_name='bigrams', num_proc=self.num_proc) input_field_names.append('bigrams') if self.trigrams: for name, dataset in data_bundle.iter_datasets(): - dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in - zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], - field_name='chars', new_field_name='trigrams') + dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc) input_field_names.append('trigrams') # index _indexize(data_bundle, input_field_names, 'target') - input_fields = ['target', 'seq_len'] + input_field_names - target_fields = ['target'] - for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len('chars') @@ -637,8 +642,8 @@ class THUCNewsPipe(CLSBasePipe): data_bundle.get_vocab('trigrams')获取. """ - def __init__(self, bigrams=False, trigrams=False): - super().__init__() + def __init__(self, bigrams=False, trigrams=False, num_proc=0): + super().__init__(num_proc=num_proc) self.bigrams = bigrams self.trigrams = trigrams @@ -653,7 +658,7 @@ class THUCNewsPipe(CLSBasePipe): def _tokenize(self, data_bundle, field_name='words', new_field_name=None): new_field_name = new_field_name or field_name for name, dataset in data_bundle.datasets.items(): - dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name) + dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name, num_proc=self.num_proc) return data_bundle def process(self, data_bundle: DataBundle): @@ -680,17 +685,21 @@ class THUCNewsPipe(CLSBasePipe): input_field_names = ['chars'] + def bigrams(chars): + return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])] + + def trigrams(chars): + return [c1 + c2 + c3 for c1, c2, c3 in + zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)] + # n-grams if self.bigrams: for name, dataset in data_bundle.iter_datasets(): - dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])], - field_name='chars', new_field_name='bigrams') + dataset.apply_field(bigrams, field_name='chars', new_field_name='bigrams', num_proc=self.num_proc) input_field_names.append('bigrams') if self.trigrams: for name, dataset in data_bundle.iter_datasets(): - dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in - zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], - field_name='chars', new_field_name='trigrams') + dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc) input_field_names.append('trigrams') # index @@ -700,9 +709,6 @@ class THUCNewsPipe(CLSBasePipe): for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(field_name='chars', new_field_name='seq_len') - input_fields = ['target', 'seq_len'] + input_field_names - target_fields = ['target'] - return data_bundle def process_from_file(self, paths=None): @@ -746,8 +752,8 @@ class WeiboSenti100kPipe(CLSBasePipe): data_bundle.get_vocab('trigrams')获取. """ - def __init__(self, bigrams=False, trigrams=False): - super().__init__() + def __init__(self, bigrams=False, trigrams=False, num_proc=0): + super().__init__(num_proc=num_proc) self.bigrams = bigrams self.trigrams = trigrams @@ -758,7 +764,8 @@ class WeiboSenti100kPipe(CLSBasePipe): def _tokenize(self, data_bundle, field_name='words', new_field_name=None): new_field_name = new_field_name or field_name for name, dataset in data_bundle.datasets.items(): - dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name) + dataset.apply_field(self._chracter_split, field_name=field_name, + new_field_name=new_field_name, num_proc=self.num_proc) return data_bundle def process(self, data_bundle: DataBundle): @@ -779,20 +786,19 @@ class WeiboSenti100kPipe(CLSBasePipe): # CWS(tokenize) data_bundle = self._tokenize(data_bundle=data_bundle, field_name='raw_chars', new_field_name='chars') - input_field_names = ['chars'] + def bigrams(chars): + return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])] + def trigrams(chars): + return [c1 + c2 + c3 for c1, c2, c3 in + zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)] # n-grams if self.bigrams: for name, dataset in data_bundle.iter_datasets(): - dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])], - field_name='chars', new_field_name='bigrams') - input_field_names.append('bigrams') + dataset.apply_field(bigrams, field_name='chars', new_field_name='bigrams', num_proc=self.num_proc) if self.trigrams: for name, dataset in data_bundle.iter_datasets(): - dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in - zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], - field_name='chars', new_field_name='trigrams') - input_field_names.append('trigrams') + dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc) # index data_bundle = _indexize(data_bundle=data_bundle, input_field_names='chars') @@ -801,9 +807,6 @@ class WeiboSenti100kPipe(CLSBasePipe): for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(field_name='chars', new_field_name='seq_len') - input_fields = ['target', 'seq_len'] + input_field_names - target_fields = ['target'] - return data_bundle def process_from_file(self, paths=None): @@ -817,13 +820,13 @@ class WeiboSenti100kPipe(CLSBasePipe): return data_bundle class MRPipe(CLSBasePipe): - def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0): r""" :param bool lower: 是否将words列的数据小写。 :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 """ - super().__init__(tokenizer=tokenizer, lang='en') + super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) self.lower = lower def process_from_file(self, paths=None): @@ -840,13 +843,13 @@ class MRPipe(CLSBasePipe): class R8Pipe(CLSBasePipe): - def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc = 0): r""" :param bool lower: 是否将words列的数据小写。 :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 """ - super().__init__(tokenizer=tokenizer, lang='en') + super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) self.lower = lower def process_from_file(self, paths=None): @@ -863,13 +866,13 @@ class R8Pipe(CLSBasePipe): class R52Pipe(CLSBasePipe): - def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int = 0): r""" :param bool lower: 是否将words列的数据小写。 :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 """ - super().__init__(tokenizer=tokenizer, lang='en') + super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) self.lower = lower def process_from_file(self, paths=None): @@ -886,13 +889,13 @@ class R52Pipe(CLSBasePipe): class OhsumedPipe(CLSBasePipe): - def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int = 0): r""" :param bool lower: 是否将words列的数据小写。 :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 """ - super().__init__(tokenizer=tokenizer, lang='en') + super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) self.lower = lower def process_from_file(self, paths=None): @@ -909,13 +912,13 @@ class OhsumedPipe(CLSBasePipe): class NG20Pipe(CLSBasePipe): - def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int = 0): r""" :param bool lower: 是否将words列的数据小写。 :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 """ - super().__init__(tokenizer=tokenizer, lang='en') + super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) self.lower = lower def process_from_file(self, paths=None): diff --git a/fastNLP/io/pipe/conll.py b/fastNLP/io/pipe/conll.py index 43982363..efe05de0 100644 --- a/fastNLP/io/pipe/conll.py +++ b/fastNLP/io/pipe/conll.py @@ -30,7 +30,7 @@ class _NERPipe(Pipe): target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target, seq_len。 """ - def __init__(self, encoding_type: str = 'bio', lower: bool = False): + def __init__(self, encoding_type: str = 'bio', lower: bool = False, num_proc=0): r""" :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 @@ -39,10 +39,14 @@ class _NERPipe(Pipe): if encoding_type == 'bio': self.convert_tag = iob2 elif encoding_type == 'bioes': - self.convert_tag = lambda words: iob2bioes(iob2(words)) + def func(words): + return iob2bioes(iob2(words)) + # self.convert_tag = lambda words: iob2bioes(iob2(words)) + self.convert_tag = func else: raise ValueError("encoding_type only supports `bio` and `bioes`.") self.lower = lower + self.num_proc = num_proc def process(self, data_bundle: DataBundle) -> DataBundle: r""" @@ -60,16 +64,13 @@ class _NERPipe(Pipe): """ # 转换tag for name, dataset in data_bundle.iter_datasets(): - dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target') + dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target', num_proc=self.num_proc) _add_words_field(data_bundle, lower=self.lower) # index _indexize(data_bundle) - input_fields = ['target', 'words', 'seq_len'] - target_fields = ['target', 'seq_len'] - for name, dataset in data_bundle.iter_datasets(): dataset.add_seq_len('words') @@ -144,7 +145,7 @@ class Conll2003Pipe(Pipe): """ - def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False): + def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False, num_proc: int = 0): r""" :param str chunk_encoding_type: 支持bioes, bio。 @@ -154,16 +155,23 @@ class Conll2003Pipe(Pipe): if chunk_encoding_type == 'bio': self.chunk_convert_tag = iob2 elif chunk_encoding_type == 'bioes': - self.chunk_convert_tag = lambda tags: iob2bioes(iob2(tags)) + def func1(tags): + return iob2bioes(iob2(tags)) + # self.chunk_convert_tag = lambda tags: iob2bioes(iob2(tags)) + self.chunk_convert_tag = func1 else: raise ValueError("chunk_encoding_type only supports `bio` and `bioes`.") if ner_encoding_type == 'bio': self.ner_convert_tag = iob2 elif ner_encoding_type == 'bioes': - self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags)) + def func2(tags): + return iob2bioes(iob2(tags)) + # self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags)) + self.ner_convert_tag = func2 else: raise ValueError("ner_encoding_type only supports `bio` and `bioes`.") self.lower = lower + self.num_proc = num_proc def process(self, data_bundle) -> DataBundle: r""" @@ -182,8 +190,8 @@ class Conll2003Pipe(Pipe): # 转换tag for name, dataset in data_bundle.datasets.items(): dataset.drop(lambda x: "-DOCSTART-" in x['raw_words']) - dataset.apply_field(self.chunk_convert_tag, field_name='chunk', new_field_name='chunk') - dataset.apply_field(self.ner_convert_tag, field_name='ner', new_field_name='ner') + dataset.apply_field(self.chunk_convert_tag, field_name='chunk', new_field_name='chunk', num_proc=self.num_proc) + dataset.apply_field(self.ner_convert_tag, field_name='ner', new_field_name='ner', num_proc=self.num_proc) _add_words_field(data_bundle, lower=self.lower) @@ -194,10 +202,7 @@ class Conll2003Pipe(Pipe): tgt_vocab.from_dataset(*data_bundle.datasets.values(), field_name='chunk') tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name='chunk') data_bundle.set_vocab(tgt_vocab, 'chunk') - - input_fields = ['words', 'seq_len'] - target_fields = ['pos', 'ner', 'chunk', 'seq_len'] - + for name, dataset in data_bundle.iter_datasets(): dataset.add_seq_len('words') @@ -256,7 +261,7 @@ class _CNNERPipe(Pipe): """ - def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False): + def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False, num_proc: int = 0): r""" :param str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 @@ -270,12 +275,16 @@ class _CNNERPipe(Pipe): if encoding_type == 'bio': self.convert_tag = iob2 elif encoding_type == 'bioes': - self.convert_tag = lambda words: iob2bioes(iob2(words)) + def func(words): + return iob2bioes(iob2(words)) + # self.convert_tag = lambda words: iob2bioes(iob2(words)) + self.convert_tag = func else: raise ValueError("encoding_type only supports `bio` and `bioes`.") self.bigrams = bigrams self.trigrams = trigrams + self.num_proc = num_proc def process(self, data_bundle: DataBundle) -> DataBundle: r""" @@ -296,29 +305,31 @@ class _CNNERPipe(Pipe): """ # 转换tag for name, dataset in data_bundle.datasets.items(): - dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target') + dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target', num_proc=self.num_proc) _add_chars_field(data_bundle, lower=False) input_field_names = ['chars'] + + def bigrams(chars): + return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])] + + def trigrams(chars): + return [c1 + c2 + c3 for c1, c2, c3 in + zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)] + if self.bigrams: for name, dataset in data_bundle.iter_datasets(): - dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])], - field_name='chars', new_field_name='bigrams') + dataset.apply_field(bigrams, field_name='chars', new_field_name='bigrams', num_proc=self.num_proc) input_field_names.append('bigrams') if self.trigrams: for name, dataset in data_bundle.datasets.items(): - dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in - zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], - field_name='chars', new_field_name='trigrams') + dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc) input_field_names.append('trigrams') # index _indexize(data_bundle, input_field_names, 'target') - input_fields = ['target', 'seq_len'] + input_field_names - target_fields = ['target', 'seq_len'] - for name, dataset in data_bundle.iter_datasets(): dataset.add_seq_len('chars') diff --git a/fastNLP/io/pipe/cws.py b/fastNLP/io/pipe/cws.py index 5983201e..2937f147 100644 --- a/fastNLP/io/pipe/cws.py +++ b/fastNLP/io/pipe/cws.py @@ -157,7 +157,8 @@ class CWSPipe(Pipe): """ - def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, bigrams=False, trigrams=False): + def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, + bigrams=False, trigrams=False, num_proc: int = 0): r""" :param str,None dataset_name: 支持'pku', 'msra', 'cityu', 'as', None @@ -176,6 +177,7 @@ class CWSPipe(Pipe): self.bigrams = bigrams self.trigrams = trigrams self.replace_num_alpha = replace_num_alpha + self.num_proc = num_proc def _tokenize(self, data_bundle): r""" @@ -213,7 +215,7 @@ class CWSPipe(Pipe): for name, dataset in data_bundle.iter_datasets(): dataset.apply_field(split_word_into_chars, field_name='chars', - new_field_name='chars') + new_field_name='chars', num_proc=self.num_proc) return data_bundle def process(self, data_bundle: DataBundle) -> DataBundle: @@ -233,33 +235,40 @@ class CWSPipe(Pipe): data_bundle.copy_field('raw_words', 'chars') if self.replace_num_alpha: - data_bundle.apply_field(_find_and_replace_alpha_spans, 'chars', 'chars') - data_bundle.apply_field(_find_and_replace_digit_spans, 'chars', 'chars') + data_bundle.apply_field(_find_and_replace_alpha_spans, 'chars', 'chars', num_proc=self.num_proc) + data_bundle.apply_field(_find_and_replace_digit_spans, 'chars', 'chars', num_proc=self.num_proc) self._tokenize(data_bundle) + + def func1(chars): + return self.word_lens_to_tags(map(len, chars)) + + def func2(chars): + return list(chain(*chars)) for name, dataset in data_bundle.iter_datasets(): - dataset.apply_field(lambda chars: self.word_lens_to_tags(map(len, chars)), field_name='chars', - new_field_name='target') - dataset.apply_field(lambda chars: list(chain(*chars)), field_name='chars', - new_field_name='chars') + dataset.apply_field(func1, field_name='chars', new_field_name='target', num_proc=self.num_proc) + dataset.apply_field(func2, field_name='chars', new_field_name='chars', num_proc=self.num_proc) input_field_names = ['chars'] + + def bigram(chars): + return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])] + + def trigrams(chars): + return [c1 + c2 + c3 for c1, c2, c3 in + zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)] + if self.bigrams: for name, dataset in data_bundle.iter_datasets(): - dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])], - field_name='chars', new_field_name='bigrams') + dataset.apply_field(bigram, field_name='chars', new_field_name='bigrams', num_proc=self.num_proc) input_field_names.append('bigrams') if self.trigrams: for name, dataset in data_bundle.iter_datasets(): - dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in - zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], - field_name='chars', new_field_name='trigrams') + dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc) input_field_names.append('trigrams') _indexize(data_bundle, input_field_names, 'target') - - input_fields = ['target', 'seq_len'] + input_field_names - target_fields = ['target', 'seq_len'] + for name, dataset in data_bundle.iter_datasets(): dataset.add_seq_len('chars') diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py index baebdbaa..5b9981c2 100644 --- a/fastNLP/io/pipe/matching.py +++ b/fastNLP/io/pipe/matching.py @@ -23,6 +23,7 @@ __all__ = [ "GranularizePipe", "MachingTruncatePipe", ] +from functools import partial from fastNLP.core.log import logger from .pipe import Pipe @@ -63,7 +64,7 @@ class MatchingBertPipe(Pipe): """ - def __init__(self, lower=False, tokenizer: str = 'raw'): + def __init__(self, lower=False, tokenizer: str = 'raw', num_proc: int = 0): r""" :param bool lower: 是否将word小写化。 @@ -73,6 +74,7 @@ class MatchingBertPipe(Pipe): self.lower = bool(lower) self.tokenizer = get_tokenizer(tokenize_method=tokenizer) + self.num_proc = num_proc def _tokenize(self, data_bundle, field_names, new_field_names): r""" @@ -84,8 +86,7 @@ class MatchingBertPipe(Pipe): """ for name, dataset in data_bundle.iter_datasets(): for field_name, new_field_name in zip(field_names, new_field_names): - dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, - new_field_name=new_field_name) + dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name, num_proc=self.num_proc) return data_bundle def process(self, data_bundle): @@ -124,8 +125,8 @@ class MatchingBertPipe(Pipe): words = words0 + ['[SEP]'] + words1 return words - for name, dataset in data_bundle.datasets.items(): - dataset.apply(concat, new_field_name='words') + for name, dataset in data_bundle.iter_datasets(): + dataset.apply(concat, new_field_name='words', num_proc=self.num_proc) dataset.delete_field('words1') dataset.delete_field('words2') @@ -155,10 +156,7 @@ class MatchingBertPipe(Pipe): data_bundle.set_vocab(word_vocab, 'words') data_bundle.set_vocab(target_vocab, 'target') - - input_fields = ['words', 'seq_len'] - target_fields = ['target'] - + for name, dataset in data_bundle.iter_datasets(): dataset.add_seq_len('words') @@ -223,7 +221,7 @@ class MatchingPipe(Pipe): """ - def __init__(self, lower=False, tokenizer: str = 'raw'): + def __init__(self, lower=False, tokenizer: str = 'raw', num_proc: int = 0): r""" :param bool lower: 是否将所有raw_words转为小写。 @@ -233,6 +231,7 @@ class MatchingPipe(Pipe): self.lower = bool(lower) self.tokenizer = get_tokenizer(tokenize_method=tokenizer) + self.num_proc = num_proc def _tokenize(self, data_bundle, field_names, new_field_names): r""" @@ -244,8 +243,7 @@ class MatchingPipe(Pipe): """ for name, dataset in data_bundle.iter_datasets(): for field_name, new_field_name in zip(field_names, new_field_names): - dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, - new_field_name=new_field_name) + dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name, num_proc=self.num_proc) return data_bundle def process(self, data_bundle): @@ -300,10 +298,7 @@ class MatchingPipe(Pipe): data_bundle.set_vocab(word_vocab, 'words1') data_bundle.set_vocab(target_vocab, 'target') - - input_fields = ['words1', 'words2', 'seq_len1', 'seq_len2'] - target_fields = ['target'] - + for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len('words1', 'seq_len1') dataset.add_seq_len('words2', 'seq_len2') @@ -342,8 +337,8 @@ class MNLIPipe(MatchingPipe): class LCQMCPipe(MatchingPipe): - def __init__(self, tokenizer='cn=char'): - super().__init__(tokenizer=tokenizer) + def __init__(self, tokenizer='cn=char', num_proc=0): + super().__init__(tokenizer=tokenizer, num_proc=num_proc) def process_from_file(self, paths=None): data_bundle = LCQMCLoader().load(paths) @@ -354,8 +349,8 @@ class LCQMCPipe(MatchingPipe): class CNXNLIPipe(MatchingPipe): - def __init__(self, tokenizer='cn-char'): - super().__init__(tokenizer=tokenizer) + def __init__(self, tokenizer='cn-char', num_proc=0): + super().__init__(tokenizer=tokenizer, num_proc=num_proc) def process_from_file(self, paths=None): data_bundle = CNXNLILoader().load(paths) @@ -367,8 +362,8 @@ class CNXNLIPipe(MatchingPipe): class BQCorpusPipe(MatchingPipe): - def __init__(self, tokenizer='cn-char'): - super().__init__(tokenizer=tokenizer) + def __init__(self, tokenizer='cn-char', num_proc=0): + super().__init__(tokenizer=tokenizer, num_proc=num_proc) def process_from_file(self, paths=None): data_bundle = BQCorpusLoader().load(paths) @@ -379,9 +374,10 @@ class BQCorpusPipe(MatchingPipe): class RenamePipe(Pipe): - def __init__(self, task='cn-nli'): + def __init__(self, task='cn-nli', num_proc=0): super().__init__() self.task = task + self.num_proc = num_proc def process(self, data_bundle: DataBundle): # rename field name for Chinese Matching dataset if (self.task == 'cn-nli'): @@ -419,9 +415,10 @@ class RenamePipe(Pipe): class GranularizePipe(Pipe): - def __init__(self, task=None): + def __init__(self, task=None, num_proc=0): super().__init__() self.task = task + self.num_proc = num_proc def _granularize(self, data_bundle, tag_map): r""" @@ -434,8 +431,7 @@ class GranularizePipe(Pipe): """ for name in list(data_bundle.datasets.keys()): dataset = data_bundle.get_dataset(name) - dataset.apply_field(lambda target: tag_map.get(target, -100), field_name='target', - new_field_name='target') + dataset.apply_field(lambda target: tag_map.get(target, -100), field_name='target', new_field_name='target') dataset.drop(lambda ins: ins['target'] == -100) data_bundle.set_dataset(dataset, name) return data_bundle @@ -462,8 +458,8 @@ class MachingTruncatePipe(Pipe): # truncate sentence for bert, modify seq_len class LCQMCBertPipe(MatchingBertPipe): - def __init__(self, tokenizer='cn=char'): - super().__init__(tokenizer=tokenizer) + def __init__(self, tokenizer='cn=char', num_proc=0): + super().__init__(tokenizer=tokenizer, num_proc=num_proc) def process_from_file(self, paths=None): data_bundle = LCQMCLoader().load(paths) @@ -475,8 +471,8 @@ class LCQMCBertPipe(MatchingBertPipe): class BQCorpusBertPipe(MatchingBertPipe): - def __init__(self, tokenizer='cn-char'): - super().__init__(tokenizer=tokenizer) + def __init__(self, tokenizer='cn-char', num_proc=0): + super().__init__(tokenizer=tokenizer, num_proc=num_proc) def process_from_file(self, paths=None): data_bundle = BQCorpusLoader().load(paths) @@ -488,8 +484,8 @@ class BQCorpusBertPipe(MatchingBertPipe): class CNXNLIBertPipe(MatchingBertPipe): - def __init__(self, tokenizer='cn-char'): - super().__init__(tokenizer=tokenizer) + def __init__(self, tokenizer='cn-char', num_proc=0): + super().__init__(tokenizer=tokenizer, num_proc=num_proc) def process_from_file(self, paths=None): data_bundle = CNXNLILoader().load(paths) @@ -502,9 +498,10 @@ class CNXNLIBertPipe(MatchingBertPipe): class TruncateBertPipe(Pipe): - def __init__(self, task='cn'): + def __init__(self, task='cn', num_proc=0): super().__init__() self.task = task + self.num_proc = num_proc def _truncate(self, sentence_index:list, sep_index_vocab): # 根据[SEP]在vocab中的index,找到[SEP]在dataset的field['words']中的index @@ -528,7 +525,8 @@ class TruncateBertPipe(Pipe): for name in data_bundle.datasets.keys(): dataset = data_bundle.get_dataset(name) sep_index_vocab = data_bundle.get_vocab('words').to_index('[SEP]') - dataset.apply_field(lambda sent_index: self._truncate(sentence_index=sent_index, sep_index_vocab=sep_index_vocab), field_name='words', new_field_name='words') + dataset.apply_field(partial(self._truncate, sep_index_vocab=sep_index_vocab), field_name='words', + new_field_name='words', num_proc=self.num_proc) # truncate之后需要更新seq_len dataset.add_seq_len(field_name='words') diff --git a/fastNLP/io/pipe/summarization.py b/fastNLP/io/pipe/summarization.py index 359801c4..b413890b 100644 --- a/fastNLP/io/pipe/summarization.py +++ b/fastNLP/io/pipe/summarization.py @@ -1,6 +1,7 @@ r"""undocumented""" import os import numpy as np +from functools import partial from .pipe import Pipe from .utils import _drop_empty_instance @@ -25,7 +26,7 @@ class ExtCNNDMPipe(Pipe): :header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target" """ - def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False): + def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False, num_proc=0): r""" :param vocab_size: int, 词表大小 @@ -39,6 +40,7 @@ class ExtCNNDMPipe(Pipe): self.sent_max_len = sent_max_len self.doc_max_timesteps = doc_max_timesteps self.domain = domain + self.num_proc = num_proc def process(self, data_bundle: DataBundle): r""" @@ -65,18 +67,29 @@ class ExtCNNDMPipe(Pipe): error_msg = 'vocab file is not defined!' print(error_msg) raise RuntimeError(error_msg) - data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text') - data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') - data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd') + data_bundle.apply_field(_lower_text, field_name='text', new_field_name='text', num_proc=self.num_proc) + data_bundle.apply_field(_lower_text, field_name='summary', new_field_name='summary', num_proc=self.num_proc) + data_bundle.apply_field(_split_list, field_name='text', new_field_name='text_wd', num_proc=self.num_proc) + # data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text') + # data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') + # data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd') data_bundle.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name='target') - data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name='words') + data_bundle.apply_field(partial(_pad_sent, sent_max_len=self.sent_max_len), field_name="text_wd", + new_field_name="words", num_proc=self.num_proc) + # data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name='words') # db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask") # pad document - data_bundle.apply(lambda x: _pad_doc(x['words'], self.sent_max_len, self.doc_max_timesteps), new_field_name='words') - data_bundle.apply(lambda x: _sent_mask(x['words'], self.doc_max_timesteps), new_field_name='seq_len') - data_bundle.apply(lambda x: _pad_label(x['target'], self.doc_max_timesteps), new_field_name='target') + data_bundle.apply_field(partial(_pad_doc, sent_max_len=self.sent_max_len, doc_max_timesteps=self.doc_max_timesteps), + field_name="words", new_field_name="words", num_proc=self.num_proc) + data_bundle.apply_field(partial(_sent_mask, doc_max_timesteps=self.doc_max_timesteps), field_name="words", + new_field_name="seq_len", num_proc=self.num_proc) + data_bundle.apply_field(partial(_pad_label, doc_max_timesteps=self.doc_max_timesteps), field_name="target", + new_field_name="target", num_proc=self.num_proc) + # data_bundle.apply(lambda x: _pad_doc(x['words'], self.sent_max_len, self.doc_max_timesteps), new_field_name='words') + # data_bundle.apply(lambda x: _sent_mask(x['words'], self.doc_max_timesteps), new_field_name='seq_len') + # data_bundle.apply(lambda x: _pad_label(x['target'], self.doc_max_timesteps), new_field_name='target') data_bundle = _drop_empty_instance(data_bundle, "label") diff --git a/tests/io/pipe/test_classification.py b/tests/io/pipe/test_classification.py index 31174862..99d63149 100755 --- a/tests/io/pipe/test_classification.py +++ b/tests/io/pipe/test_classification.py @@ -12,14 +12,24 @@ class TestClassificationPipe: def test_process_from_file(self): for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]: print(pipe) - data_bundle = pipe(tokenizer='raw').process_from_file() + data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file() print(data_bundle) + def test_process_from_file_proc(self, num_proc=2): + for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]: + print(pipe) + data_bundle = pipe(tokenizer='raw', num_proc=num_proc).process_from_file() + print(data_bundle) class TestRunPipe: def test_load(self): for pipe in [IMDBPipe]: - data_bundle = pipe(tokenizer='raw').process_from_file('tests/data_for_tests/io/imdb') + data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file('tests/data_for_tests/io/imdb') + print(data_bundle) + + def test_load_proc(self): + for pipe in [IMDBPipe]: + data_bundle = pipe(tokenizer='raw', num_proc=2).process_from_file('tests/data_for_tests/io/imdb') print(data_bundle) @@ -31,7 +41,7 @@ class TestCNClassificationPipe: print(data_bundle) -@pytest.mark.skipif('download' not in os.environ, reason="Skip download") +# @pytest.mark.skipif('download' not in os.environ, reason="Skip download") class TestRunClassificationPipe: def test_process_from_file(self): data_set_dict = { @@ -71,9 +81,9 @@ class TestRunClassificationPipe: path, pipe, data_set, vocab, warns = v if 'Chn' not in k: if warns: - data_bundle = pipe(tokenizer='raw').process_from_file(path) + data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file(path) else: - data_bundle = pipe(tokenizer='raw').process_from_file(path) + data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file(path) else: data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(path) @@ -87,3 +97,61 @@ class TestRunClassificationPipe: for name, vocabs in data_bundle.iter_vocabs(): assert(name in vocab.keys()) assert(vocab[name] == len(vocabs)) + + def test_process_from_file_proc(self): + data_set_dict = { + 'yelp.p': ('tests/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe, + {'train': 6, 'dev': 6, 'test': 6}, {'words': 1176, 'target': 2}, + False), + 'yelp.f': ('tests/data_for_tests/io/yelp_review_full', YelpFullPipe, + {'train': 6, 'dev': 6, 'test': 6}, {'words': 1166, 'target': 5}, + False), + 'sst-2': ('tests/data_for_tests/io/SST-2', SST2Pipe, + {'train': 5, 'dev': 5, 'test': 5}, {'words': 139, 'target': 2}, + True), + 'sst': ('tests/data_for_tests/io/SST', SSTPipe, + {'train': 354, 'dev': 6, 'test': 6}, {'words': 232, 'target': 5}, + False), + 'imdb': ('tests/data_for_tests/io/imdb', IMDBPipe, + {'train': 6, 'dev': 6, 'test': 6}, {'words': 1670, 'target': 2}, + False), + 'ag': ('tests/data_for_tests/io/ag', AGsNewsPipe, + {'train': 4, 'test': 5}, {'words': 257, 'target': 4}, + False), + 'dbpedia': ('tests/data_for_tests/io/dbpedia', DBPediaPipe, + {'train': 14, 'test': 5}, {'words': 496, 'target': 14}, + False), + 'ChnSentiCorp': ('tests/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe, + {'train': 6, 'dev': 6, 'test': 6}, + {'chars': 529, 'bigrams': 1296, 'trigrams': 1483, 'target': 2}, + False), + 'Chn-THUCNews': ('tests/data_for_tests/io/THUCNews', THUCNewsPipe, + {'train': 9, 'dev': 9, 'test': 9}, {'chars': 1864, 'target': 9}, + False), + 'Chn-WeiboSenti100k': ('tests/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe, + {'train': 6, 'dev': 6, 'test': 7}, {'chars': 452, 'target': 2}, + False), + } + for k, v in data_set_dict.items(): + path, pipe, data_set, vocab, warns = v + if 'Chn' not in k: + if warns: + data_bundle = pipe(tokenizer='raw', num_proc=2).process_from_file(path) + else: + data_bundle = pipe(tokenizer='raw', num_proc=2).process_from_file(path) + else: + # if k == 'ChnSentiCorp': + # data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(path) + # else: + data_bundle = pipe(bigrams=True, trigrams=True, num_proc=2).process_from_file(path) + + assert(isinstance(data_bundle, DataBundle)) + assert(len(data_set) == data_bundle.num_dataset) + for name, dataset in data_bundle.iter_datasets(): + assert(name in data_set.keys()) + assert(data_set[name] == len(dataset)) + + assert(len(vocab) == data_bundle.num_vocab) + for name, vocabs in data_bundle.iter_vocabs(): + assert(name in vocab.keys()) + assert(vocab[name] == len(vocabs)) \ No newline at end of file diff --git a/tests/io/pipe/test_conll.py b/tests/io/pipe/test_conll.py index e4000ae3..4f02a8ee 100755 --- a/tests/io/pipe/test_conll.py +++ b/tests/io/pipe/test_conll.py @@ -22,6 +22,12 @@ class TestRunPipe: data_bundle = pipe().process_from_file('tests/data_for_tests/conll_2003_example.txt') print(data_bundle) + def test_conll2003_proc(self): + for pipe in [Conll2003Pipe, Conll2003NERPipe]: + print(pipe) + data_bundle = pipe(num_proc=2).process_from_file('tests/data_for_tests/conll_2003_example.txt') + print(data_bundle) + class TestNERPipe: def test_process_from_file(self): @@ -37,12 +43,33 @@ class TestNERPipe: data_bundle = pipe(encoding_type='bioes').process_from_file(f'tests/data_for_tests/io/{k}') print(data_bundle) + def test_process_from_file_proc(self): + data_dict = { + 'weibo_NER': WeiboNERPipe, + 'peopledaily': PeopleDailyPipe, + 'MSRA_NER': MsraNERPipe, + } + for k, v in data_dict.items(): + pipe = v + data_bundle = pipe(bigrams=True, trigrams=True, num_proc=2).process_from_file(f'tests/data_for_tests/io/{k}') + print(data_bundle) + data_bundle = pipe(encoding_type='bioes', num_proc=2).process_from_file(f'tests/data_for_tests/io/{k}') + print(data_bundle) + class TestConll2003Pipe: def test_conll(self): data_bundle = Conll2003Pipe().process_from_file('tests/data_for_tests/io/conll2003') print(data_bundle) + def test_conll_proc(self): + data_bundle = Conll2003Pipe(num_proc=2).process_from_file('tests/data_for_tests/io/conll2003') + print(data_bundle) + def test_OntoNotes(self): data_bundle = OntoNotesNERPipe().process_from_file('tests/data_for_tests/io/OntoNotes') print(data_bundle) + + def test_OntoNotes_proc(self): + data_bundle = OntoNotesNERPipe(num_proc=2).process_from_file('tests/data_for_tests/io/OntoNotes') + print(data_bundle) diff --git a/tests/io/pipe/test_cws.py b/tests/io/pipe/test_cws.py index 895234b2..53d9af17 100755 --- a/tests/io/pipe/test_cws.py +++ b/tests/io/pipe/test_cws.py @@ -28,12 +28,25 @@ class TestRunCWSPipe: def test_process_from_file(self): dataset_names = ['msra', 'cityu', 'as', 'pku'] for dataset_name in dataset_names: - data_bundle = CWSPipe(bigrams=True, trigrams=True).\ + data_bundle = CWSPipe(bigrams=True, trigrams=True, num_proc=0).\ process_from_file(f'tests/data_for_tests/io/cws_{dataset_name}') print(data_bundle) def test_replace_number(self): - data_bundle = CWSPipe(bigrams=True, replace_num_alpha=True).\ + data_bundle = CWSPipe(bigrams=True, replace_num_alpha=True, num_proc=0).\ + process_from_file(f'tests/data_for_tests/io/cws_pku') + for word in ['<', '>', '']: + assert(data_bundle.get_vocab('chars').to_index(word) != 1) + + def test_process_from_file_proc(self): + dataset_names = ['msra', 'cityu', 'as', 'pku'] + for dataset_name in dataset_names: + data_bundle = CWSPipe(bigrams=True, trigrams=True, num_proc=2).\ + process_from_file(f'tests/data_for_tests/io/cws_{dataset_name}') + print(data_bundle) + + def test_replace_number_proc(self): + data_bundle = CWSPipe(bigrams=True, replace_num_alpha=True, num_proc=2).\ process_from_file(f'tests/data_for_tests/io/cws_pku') for word in ['<', '>', '']: assert(data_bundle.get_vocab('chars').to_index(word) != 1) diff --git a/tests/io/pipe/test_matching.py b/tests/io/pipe/test_matching.py index 23c8fd70..70043638 100755 --- a/tests/io/pipe/test_matching.py +++ b/tests/io/pipe/test_matching.py @@ -69,6 +69,47 @@ class TestRunMatchingPipe: name, vocabs = y assert(x + 1 if name == 'words' else x == len(vocabs)) + def test_load_proc(self): + data_set_dict = { + 'RTE': ('tests/data_for_tests/io/RTE', RTEPipe, RTEBertPipe, (5, 5, 5), (449, 2), True), + 'SNLI': ('tests/data_for_tests/io/SNLI', SNLIPipe, SNLIBertPipe, (5, 5, 5), (110, 3), False), + 'QNLI': ('tests/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True), + 'MNLI': ('tests/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True), + 'BQCorpus': ('tests/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False), + 'XNLI': ('tests/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 6, 8), (39, 3), False), + 'LCQMC': ('tests/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 5, 6), (36, 2), False), + } + for k, v in data_set_dict.items(): + path, pipe1, pipe2, data_set, vocab, warns = v + if warns: + data_bundle1 = pipe1(tokenizer='raw', num_proc=2).process_from_file(path) + data_bundle2 = pipe2(tokenizer='raw', num_proc=2).process_from_file(path) + else: + data_bundle1 = pipe1(tokenizer='raw', num_proc=2).process_from_file(path) + data_bundle2 = pipe2(tokenizer='raw', num_proc=2).process_from_file(path) + + assert (isinstance(data_bundle1, DataBundle)) + assert (len(data_set) == data_bundle1.num_dataset) + print(k) + print(data_bundle1) + print(data_bundle2) + for x, y in zip(data_set, data_bundle1.iter_datasets()): + name, dataset = y + assert (x == len(dataset)) + assert (len(data_set) == data_bundle2.num_dataset) + for x, y in zip(data_set, data_bundle2.iter_datasets()): + name, dataset = y + assert (x == len(dataset)) + + assert (len(vocab) == data_bundle1.num_vocab) + for x, y in zip(vocab, data_bundle1.iter_vocabs()): + name, vocabs = y + assert (x == len(vocabs)) + assert (len(vocab) == data_bundle2.num_vocab) + for x, y in zip(vocab, data_bundle1.iter_vocabs()): + name, vocabs = y + assert (x + 1 if name == 'words' else x == len(vocabs)) + @pytest.mark.skipif('download' not in os.environ, reason="Skip download") def test_spacy(self): data_set_dict = { diff --git a/tests/io/pipe/test_summary.py b/tests/io/pipe/test_summary.py index b8692791..12d81a1d 100755 --- a/tests/io/pipe/test_summary.py +++ b/tests/io/pipe/test_summary.py @@ -69,3 +69,45 @@ class TestRunExtCNNDMPipe: db5 = dbPipe5.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl')) assert(isinstance(db5, DataBundle)) + def test_load_proc(self): + data_dir = 'tests/data_for_tests/io/cnndm' + vocab_size = 100000 + VOCAL_FILE = 'tests/data_for_tests/io/cnndm/vocab' + sent_max_len = 100 + doc_max_timesteps = 50 + dbPipe = ExtCNNDMPipe(vocab_size=vocab_size, + vocab_path=VOCAL_FILE, + sent_max_len=sent_max_len, + doc_max_timesteps=doc_max_timesteps, num_proc=2) + dbPipe2 = ExtCNNDMPipe(vocab_size=vocab_size, + vocab_path=VOCAL_FILE, + sent_max_len=sent_max_len, + doc_max_timesteps=doc_max_timesteps, + domain=True, num_proc=2) + db = dbPipe.process_from_file(data_dir) + db2 = dbPipe2.process_from_file(data_dir) + + assert(isinstance(db, DataBundle)) + assert(isinstance(db2, DataBundle)) + + dbPipe3 = ExtCNNDMPipe(vocab_size=vocab_size, + sent_max_len=sent_max_len, + doc_max_timesteps=doc_max_timesteps, + domain=True, num_proc=2) + db3 = dbPipe3.process_from_file(data_dir) + assert(isinstance(db3, DataBundle)) + + with pytest.raises(RuntimeError): + dbPipe4 = ExtCNNDMPipe(vocab_size=vocab_size, + sent_max_len=sent_max_len, + doc_max_timesteps=doc_max_timesteps, num_proc=2) + db4 = dbPipe4.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl')) + + dbPipe5 = ExtCNNDMPipe(vocab_size=vocab_size, + vocab_path=VOCAL_FILE, + sent_max_len=sent_max_len, + doc_max_timesteps=doc_max_timesteps, num_proc=2) + db5 = dbPipe5.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl')) + assert(isinstance(db5, DataBundle)) + +