@@ -37,10 +37,11 @@ from fastNLP.core.log import logger | |||
class CLSBasePipe(Pipe): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'raw', lang='en'): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'raw', lang='en', num_proc=0): | |||
super().__init__() | |||
self.lower = lower | |||
self.tokenizer = get_tokenizer(tokenizer, lang=lang) | |||
self.num_proc = num_proc | |||
def _tokenize(self, data_bundle, field_name='words', new_field_name=None): | |||
r""" | |||
@@ -53,7 +54,8 @@ class CLSBasePipe(Pipe): | |||
""" | |||
new_field_name = new_field_name or field_name | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) | |||
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name, | |||
num_proc=self.num_proc) | |||
return data_bundle | |||
@@ -117,7 +119,7 @@ class YelpFullPipe(CLSBasePipe): | |||
""" | |||
def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy'): | |||
def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy', num_proc=0): | |||
r""" | |||
:param bool lower: 是否对输入进行小写化。 | |||
@@ -125,7 +127,7 @@ class YelpFullPipe(CLSBasePipe): | |||
1、2归为1类,3归为1类,4、5归为1类;若为5, 则有5分类问题。 | |||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||
""" | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc) | |||
assert granularity in (2, 3, 5), "granularity can only be 2,3,5." | |||
self.granularity = granularity | |||
@@ -191,13 +193,13 @@ class YelpPolarityPipe(CLSBasePipe): | |||
""" | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0): | |||
r""" | |||
:param bool lower: 是否对输入进行小写化。 | |||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||
""" | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc) | |||
def process_from_file(self, paths=None): | |||
r""" | |||
@@ -233,13 +235,13 @@ class AGsNewsPipe(CLSBasePipe): | |||
""" | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0): | |||
r""" | |||
:param bool lower: 是否对输入进行小写化。 | |||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||
""" | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc) | |||
def process_from_file(self, paths=None): | |||
r""" | |||
@@ -274,13 +276,13 @@ class DBPediaPipe(CLSBasePipe): | |||
""" | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0): | |||
r""" | |||
:param bool lower: 是否对输入进行小写化。 | |||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||
""" | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc) | |||
def process_from_file(self, paths=None): | |||
r""" | |||
@@ -315,7 +317,7 @@ class SSTPipe(CLSBasePipe): | |||
""" | |||
def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'): | |||
def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy', num_proc=0): | |||
r""" | |||
:param bool subtree: 是否将train, test, dev数据展开为子树,扩充数据量。 Default: ``False`` | |||
@@ -325,7 +327,7 @@ class SSTPipe(CLSBasePipe): | |||
0、1归为1类,2归为1类,3、4归为1类;若为5, 则有5分类问题。 | |||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||
""" | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) | |||
self.subtree = subtree | |||
self.train_tree = train_subtree | |||
self.lower = lower | |||
@@ -407,13 +409,13 @@ class SST2Pipe(CLSBasePipe): | |||
""" | |||
def __init__(self, lower=False, tokenizer='raw'): | |||
def __init__(self, lower=False, tokenizer='raw', num_proc=0): | |||
r""" | |||
:param bool lower: 是否对输入进行小写化。 | |||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。 | |||
""" | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc) | |||
def process_from_file(self, paths=None): | |||
r""" | |||
@@ -452,13 +454,13 @@ class IMDBPipe(CLSBasePipe): | |||
""" | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0): | |||
r""" | |||
:param bool lower: 是否将words列的数据小写。 | |||
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 | |||
""" | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) | |||
self.lower = lower | |||
def process(self, data_bundle: DataBundle): | |||
@@ -483,7 +485,7 @@ class IMDBPipe(CLSBasePipe): | |||
return raw_words | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(replace_br, field_name='raw_words', new_field_name='raw_words') | |||
dataset.apply_field(replace_br, field_name='raw_words', new_field_name='raw_words', num_proc=self.num_proc) | |||
data_bundle = super().process(data_bundle) | |||
@@ -527,7 +529,7 @@ class ChnSentiCorpPipe(Pipe): | |||
""" | |||
def __init__(self, bigrams=False, trigrams=False): | |||
def __init__(self, bigrams=False, trigrams=False, num_proc: int = 0): | |||
r""" | |||
:param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 | |||
@@ -541,10 +543,11 @@ class ChnSentiCorpPipe(Pipe): | |||
self.bigrams = bigrams | |||
self.trigrams = trigrams | |||
self.num_proc = num_proc | |||
def _tokenize(self, data_bundle): | |||
r""" | |||
将DataSet中的"复旦大学"拆分为["复", "旦", "大", "学"]. 未来可以通过扩展这个函数实现分词。 | |||
将 DataSet 中的"复旦大学"拆分为 ["复", "旦", "大", "学"] . 未来可以通过扩展这个函数实现分词。 | |||
:param data_bundle: | |||
:return: | |||
@@ -571,24 +574,26 @@ class ChnSentiCorpPipe(Pipe): | |||
data_bundle = self._tokenize(data_bundle) | |||
input_field_names = ['chars'] | |||
def bigrams(chars): | |||
return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])] | |||
def trigrams(chars): | |||
return [c1 + c2 + c3 for c1, c2, c3 in | |||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)] | |||
if self.bigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])], | |||
field_name='chars', new_field_name='bigrams') | |||
dataset.apply_field(bigrams,field_name='chars', new_field_name='bigrams', num_proc=self.num_proc) | |||
input_field_names.append('bigrams') | |||
if self.trigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in | |||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)], | |||
field_name='chars', new_field_name='trigrams') | |||
dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc) | |||
input_field_names.append('trigrams') | |||
# index | |||
_indexize(data_bundle, input_field_names, 'target') | |||
input_fields = ['target', 'seq_len'] + input_field_names | |||
target_fields = ['target'] | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len('chars') | |||
@@ -637,8 +642,8 @@ class THUCNewsPipe(CLSBasePipe): | |||
data_bundle.get_vocab('trigrams')获取. | |||
""" | |||
def __init__(self, bigrams=False, trigrams=False): | |||
super().__init__() | |||
def __init__(self, bigrams=False, trigrams=False, num_proc=0): | |||
super().__init__(num_proc=num_proc) | |||
self.bigrams = bigrams | |||
self.trigrams = trigrams | |||
@@ -653,7 +658,7 @@ class THUCNewsPipe(CLSBasePipe): | |||
def _tokenize(self, data_bundle, field_name='words', new_field_name=None): | |||
new_field_name = new_field_name or field_name | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name) | |||
dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name, num_proc=self.num_proc) | |||
return data_bundle | |||
def process(self, data_bundle: DataBundle): | |||
@@ -680,17 +685,21 @@ class THUCNewsPipe(CLSBasePipe): | |||
input_field_names = ['chars'] | |||
def bigrams(chars): | |||
return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])] | |||
def trigrams(chars): | |||
return [c1 + c2 + c3 for c1, c2, c3 in | |||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)] | |||
# n-grams | |||
if self.bigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])], | |||
field_name='chars', new_field_name='bigrams') | |||
dataset.apply_field(bigrams, field_name='chars', new_field_name='bigrams', num_proc=self.num_proc) | |||
input_field_names.append('bigrams') | |||
if self.trigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in | |||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)], | |||
field_name='chars', new_field_name='trigrams') | |||
dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc) | |||
input_field_names.append('trigrams') | |||
# index | |||
@@ -700,9 +709,6 @@ class THUCNewsPipe(CLSBasePipe): | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(field_name='chars', new_field_name='seq_len') | |||
input_fields = ['target', 'seq_len'] + input_field_names | |||
target_fields = ['target'] | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
@@ -746,8 +752,8 @@ class WeiboSenti100kPipe(CLSBasePipe): | |||
data_bundle.get_vocab('trigrams')获取. | |||
""" | |||
def __init__(self, bigrams=False, trigrams=False): | |||
super().__init__() | |||
def __init__(self, bigrams=False, trigrams=False, num_proc=0): | |||
super().__init__(num_proc=num_proc) | |||
self.bigrams = bigrams | |||
self.trigrams = trigrams | |||
@@ -758,7 +764,8 @@ class WeiboSenti100kPipe(CLSBasePipe): | |||
def _tokenize(self, data_bundle, field_name='words', new_field_name=None): | |||
new_field_name = new_field_name or field_name | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name) | |||
dataset.apply_field(self._chracter_split, field_name=field_name, | |||
new_field_name=new_field_name, num_proc=self.num_proc) | |||
return data_bundle | |||
def process(self, data_bundle: DataBundle): | |||
@@ -779,20 +786,19 @@ class WeiboSenti100kPipe(CLSBasePipe): | |||
# CWS(tokenize) | |||
data_bundle = self._tokenize(data_bundle=data_bundle, field_name='raw_chars', new_field_name='chars') | |||
input_field_names = ['chars'] | |||
def bigrams(chars): | |||
return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])] | |||
def trigrams(chars): | |||
return [c1 + c2 + c3 for c1, c2, c3 in | |||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)] | |||
# n-grams | |||
if self.bigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])], | |||
field_name='chars', new_field_name='bigrams') | |||
input_field_names.append('bigrams') | |||
dataset.apply_field(bigrams, field_name='chars', new_field_name='bigrams', num_proc=self.num_proc) | |||
if self.trigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in | |||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)], | |||
field_name='chars', new_field_name='trigrams') | |||
input_field_names.append('trigrams') | |||
dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc) | |||
# index | |||
data_bundle = _indexize(data_bundle=data_bundle, input_field_names='chars') | |||
@@ -801,9 +807,6 @@ class WeiboSenti100kPipe(CLSBasePipe): | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(field_name='chars', new_field_name='seq_len') | |||
input_fields = ['target', 'seq_len'] + input_field_names | |||
target_fields = ['target'] | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
@@ -817,13 +820,13 @@ class WeiboSenti100kPipe(CLSBasePipe): | |||
return data_bundle | |||
class MRPipe(CLSBasePipe): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0): | |||
r""" | |||
:param bool lower: 是否将words列的数据小写。 | |||
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 | |||
""" | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) | |||
self.lower = lower | |||
def process_from_file(self, paths=None): | |||
@@ -840,13 +843,13 @@ class MRPipe(CLSBasePipe): | |||
class R8Pipe(CLSBasePipe): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc = 0): | |||
r""" | |||
:param bool lower: 是否将words列的数据小写。 | |||
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 | |||
""" | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) | |||
self.lower = lower | |||
def process_from_file(self, paths=None): | |||
@@ -863,13 +866,13 @@ class R8Pipe(CLSBasePipe): | |||
class R52Pipe(CLSBasePipe): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int = 0): | |||
r""" | |||
:param bool lower: 是否将words列的数据小写。 | |||
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 | |||
""" | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) | |||
self.lower = lower | |||
def process_from_file(self, paths=None): | |||
@@ -886,13 +889,13 @@ class R52Pipe(CLSBasePipe): | |||
class OhsumedPipe(CLSBasePipe): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int = 0): | |||
r""" | |||
:param bool lower: 是否将words列的数据小写。 | |||
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 | |||
""" | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) | |||
self.lower = lower | |||
def process_from_file(self, paths=None): | |||
@@ -909,13 +912,13 @@ class OhsumedPipe(CLSBasePipe): | |||
class NG20Pipe(CLSBasePipe): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int = 0): | |||
r""" | |||
:param bool lower: 是否将words列的数据小写。 | |||
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 | |||
""" | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) | |||
self.lower = lower | |||
def process_from_file(self, paths=None): | |||
@@ -30,7 +30,7 @@ class _NERPipe(Pipe): | |||
target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target, seq_len。 | |||
""" | |||
def __init__(self, encoding_type: str = 'bio', lower: bool = False): | |||
def __init__(self, encoding_type: str = 'bio', lower: bool = False, num_proc=0): | |||
r""" | |||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||
@@ -39,10 +39,14 @@ class _NERPipe(Pipe): | |||
if encoding_type == 'bio': | |||
self.convert_tag = iob2 | |||
elif encoding_type == 'bioes': | |||
self.convert_tag = lambda words: iob2bioes(iob2(words)) | |||
def func(words): | |||
return iob2bioes(iob2(words)) | |||
# self.convert_tag = lambda words: iob2bioes(iob2(words)) | |||
self.convert_tag = func | |||
else: | |||
raise ValueError("encoding_type only supports `bio` and `bioes`.") | |||
self.lower = lower | |||
self.num_proc = num_proc | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
r""" | |||
@@ -60,16 +64,13 @@ class _NERPipe(Pipe): | |||
""" | |||
# 转换tag | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target') | |||
dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target', num_proc=self.num_proc) | |||
_add_words_field(data_bundle, lower=self.lower) | |||
# index | |||
_indexize(data_bundle) | |||
input_fields = ['target', 'words', 'seq_len'] | |||
target_fields = ['target', 'seq_len'] | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.add_seq_len('words') | |||
@@ -144,7 +145,7 @@ class Conll2003Pipe(Pipe): | |||
""" | |||
def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False): | |||
def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False, num_proc: int = 0): | |||
r""" | |||
:param str chunk_encoding_type: 支持bioes, bio。 | |||
@@ -154,16 +155,23 @@ class Conll2003Pipe(Pipe): | |||
if chunk_encoding_type == 'bio': | |||
self.chunk_convert_tag = iob2 | |||
elif chunk_encoding_type == 'bioes': | |||
self.chunk_convert_tag = lambda tags: iob2bioes(iob2(tags)) | |||
def func1(tags): | |||
return iob2bioes(iob2(tags)) | |||
# self.chunk_convert_tag = lambda tags: iob2bioes(iob2(tags)) | |||
self.chunk_convert_tag = func1 | |||
else: | |||
raise ValueError("chunk_encoding_type only supports `bio` and `bioes`.") | |||
if ner_encoding_type == 'bio': | |||
self.ner_convert_tag = iob2 | |||
elif ner_encoding_type == 'bioes': | |||
self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags)) | |||
def func2(tags): | |||
return iob2bioes(iob2(tags)) | |||
# self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags)) | |||
self.ner_convert_tag = func2 | |||
else: | |||
raise ValueError("ner_encoding_type only supports `bio` and `bioes`.") | |||
self.lower = lower | |||
self.num_proc = num_proc | |||
def process(self, data_bundle) -> DataBundle: | |||
r""" | |||
@@ -182,8 +190,8 @@ class Conll2003Pipe(Pipe): | |||
# 转换tag | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.drop(lambda x: "-DOCSTART-" in x['raw_words']) | |||
dataset.apply_field(self.chunk_convert_tag, field_name='chunk', new_field_name='chunk') | |||
dataset.apply_field(self.ner_convert_tag, field_name='ner', new_field_name='ner') | |||
dataset.apply_field(self.chunk_convert_tag, field_name='chunk', new_field_name='chunk', num_proc=self.num_proc) | |||
dataset.apply_field(self.ner_convert_tag, field_name='ner', new_field_name='ner', num_proc=self.num_proc) | |||
_add_words_field(data_bundle, lower=self.lower) | |||
@@ -194,10 +202,7 @@ class Conll2003Pipe(Pipe): | |||
tgt_vocab.from_dataset(*data_bundle.datasets.values(), field_name='chunk') | |||
tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name='chunk') | |||
data_bundle.set_vocab(tgt_vocab, 'chunk') | |||
input_fields = ['words', 'seq_len'] | |||
target_fields = ['pos', 'ner', 'chunk', 'seq_len'] | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.add_seq_len('words') | |||
@@ -256,7 +261,7 @@ class _CNNERPipe(Pipe): | |||
""" | |||
def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False): | |||
def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False, num_proc: int = 0): | |||
r""" | |||
:param str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||
@@ -270,12 +275,16 @@ class _CNNERPipe(Pipe): | |||
if encoding_type == 'bio': | |||
self.convert_tag = iob2 | |||
elif encoding_type == 'bioes': | |||
self.convert_tag = lambda words: iob2bioes(iob2(words)) | |||
def func(words): | |||
return iob2bioes(iob2(words)) | |||
# self.convert_tag = lambda words: iob2bioes(iob2(words)) | |||
self.convert_tag = func | |||
else: | |||
raise ValueError("encoding_type only supports `bio` and `bioes`.") | |||
self.bigrams = bigrams | |||
self.trigrams = trigrams | |||
self.num_proc = num_proc | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
r""" | |||
@@ -296,29 +305,31 @@ class _CNNERPipe(Pipe): | |||
""" | |||
# 转换tag | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target') | |||
dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target', num_proc=self.num_proc) | |||
_add_chars_field(data_bundle, lower=False) | |||
input_field_names = ['chars'] | |||
def bigrams(chars): | |||
return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])] | |||
def trigrams(chars): | |||
return [c1 + c2 + c3 for c1, c2, c3 in | |||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)] | |||
if self.bigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])], | |||
field_name='chars', new_field_name='bigrams') | |||
dataset.apply_field(bigrams, field_name='chars', new_field_name='bigrams', num_proc=self.num_proc) | |||
input_field_names.append('bigrams') | |||
if self.trigrams: | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in | |||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)], | |||
field_name='chars', new_field_name='trigrams') | |||
dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc) | |||
input_field_names.append('trigrams') | |||
# index | |||
_indexize(data_bundle, input_field_names, 'target') | |||
input_fields = ['target', 'seq_len'] + input_field_names | |||
target_fields = ['target', 'seq_len'] | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.add_seq_len('chars') | |||
@@ -157,7 +157,8 @@ class CWSPipe(Pipe): | |||
""" | |||
def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, bigrams=False, trigrams=False): | |||
def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, | |||
bigrams=False, trigrams=False, num_proc: int = 0): | |||
r""" | |||
:param str,None dataset_name: 支持'pku', 'msra', 'cityu', 'as', None | |||
@@ -176,6 +177,7 @@ class CWSPipe(Pipe): | |||
self.bigrams = bigrams | |||
self.trigrams = trigrams | |||
self.replace_num_alpha = replace_num_alpha | |||
self.num_proc = num_proc | |||
def _tokenize(self, data_bundle): | |||
r""" | |||
@@ -213,7 +215,7 @@ class CWSPipe(Pipe): | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(split_word_into_chars, field_name='chars', | |||
new_field_name='chars') | |||
new_field_name='chars', num_proc=self.num_proc) | |||
return data_bundle | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
@@ -233,33 +235,40 @@ class CWSPipe(Pipe): | |||
data_bundle.copy_field('raw_words', 'chars') | |||
if self.replace_num_alpha: | |||
data_bundle.apply_field(_find_and_replace_alpha_spans, 'chars', 'chars') | |||
data_bundle.apply_field(_find_and_replace_digit_spans, 'chars', 'chars') | |||
data_bundle.apply_field(_find_and_replace_alpha_spans, 'chars', 'chars', num_proc=self.num_proc) | |||
data_bundle.apply_field(_find_and_replace_digit_spans, 'chars', 'chars', num_proc=self.num_proc) | |||
self._tokenize(data_bundle) | |||
def func1(chars): | |||
return self.word_lens_to_tags(map(len, chars)) | |||
def func2(chars): | |||
return list(chain(*chars)) | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: self.word_lens_to_tags(map(len, chars)), field_name='chars', | |||
new_field_name='target') | |||
dataset.apply_field(lambda chars: list(chain(*chars)), field_name='chars', | |||
new_field_name='chars') | |||
dataset.apply_field(func1, field_name='chars', new_field_name='target', num_proc=self.num_proc) | |||
dataset.apply_field(func2, field_name='chars', new_field_name='chars', num_proc=self.num_proc) | |||
input_field_names = ['chars'] | |||
def bigram(chars): | |||
return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])] | |||
def trigrams(chars): | |||
return [c1 + c2 + c3 for c1, c2, c3 in | |||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)] | |||
if self.bigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])], | |||
field_name='chars', new_field_name='bigrams') | |||
dataset.apply_field(bigram, field_name='chars', new_field_name='bigrams', num_proc=self.num_proc) | |||
input_field_names.append('bigrams') | |||
if self.trigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in | |||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)], | |||
field_name='chars', new_field_name='trigrams') | |||
dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc) | |||
input_field_names.append('trigrams') | |||
_indexize(data_bundle, input_field_names, 'target') | |||
input_fields = ['target', 'seq_len'] + input_field_names | |||
target_fields = ['target', 'seq_len'] | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.add_seq_len('chars') | |||
@@ -23,6 +23,7 @@ __all__ = [ | |||
"GranularizePipe", | |||
"MachingTruncatePipe", | |||
] | |||
from functools import partial | |||
from fastNLP.core.log import logger | |||
from .pipe import Pipe | |||
@@ -63,7 +64,7 @@ class MatchingBertPipe(Pipe): | |||
""" | |||
def __init__(self, lower=False, tokenizer: str = 'raw'): | |||
def __init__(self, lower=False, tokenizer: str = 'raw', num_proc: int = 0): | |||
r""" | |||
:param bool lower: 是否将word小写化。 | |||
@@ -73,6 +74,7 @@ class MatchingBertPipe(Pipe): | |||
self.lower = bool(lower) | |||
self.tokenizer = get_tokenizer(tokenize_method=tokenizer) | |||
self.num_proc = num_proc | |||
def _tokenize(self, data_bundle, field_names, new_field_names): | |||
r""" | |||
@@ -84,8 +86,7 @@ class MatchingBertPipe(Pipe): | |||
""" | |||
for name, dataset in data_bundle.iter_datasets(): | |||
for field_name, new_field_name in zip(field_names, new_field_names): | |||
dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, | |||
new_field_name=new_field_name) | |||
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name, num_proc=self.num_proc) | |||
return data_bundle | |||
def process(self, data_bundle): | |||
@@ -124,8 +125,8 @@ class MatchingBertPipe(Pipe): | |||
words = words0 + ['[SEP]'] + words1 | |||
return words | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply(concat, new_field_name='words') | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply(concat, new_field_name='words', num_proc=self.num_proc) | |||
dataset.delete_field('words1') | |||
dataset.delete_field('words2') | |||
@@ -155,10 +156,7 @@ class MatchingBertPipe(Pipe): | |||
data_bundle.set_vocab(word_vocab, 'words') | |||
data_bundle.set_vocab(target_vocab, 'target') | |||
input_fields = ['words', 'seq_len'] | |||
target_fields = ['target'] | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.add_seq_len('words') | |||
@@ -223,7 +221,7 @@ class MatchingPipe(Pipe): | |||
""" | |||
def __init__(self, lower=False, tokenizer: str = 'raw'): | |||
def __init__(self, lower=False, tokenizer: str = 'raw', num_proc: int = 0): | |||
r""" | |||
:param bool lower: 是否将所有raw_words转为小写。 | |||
@@ -233,6 +231,7 @@ class MatchingPipe(Pipe): | |||
self.lower = bool(lower) | |||
self.tokenizer = get_tokenizer(tokenize_method=tokenizer) | |||
self.num_proc = num_proc | |||
def _tokenize(self, data_bundle, field_names, new_field_names): | |||
r""" | |||
@@ -244,8 +243,7 @@ class MatchingPipe(Pipe): | |||
""" | |||
for name, dataset in data_bundle.iter_datasets(): | |||
for field_name, new_field_name in zip(field_names, new_field_names): | |||
dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, | |||
new_field_name=new_field_name) | |||
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name, num_proc=self.num_proc) | |||
return data_bundle | |||
def process(self, data_bundle): | |||
@@ -300,10 +298,7 @@ class MatchingPipe(Pipe): | |||
data_bundle.set_vocab(word_vocab, 'words1') | |||
data_bundle.set_vocab(target_vocab, 'target') | |||
input_fields = ['words1', 'words2', 'seq_len1', 'seq_len2'] | |||
target_fields = ['target'] | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len('words1', 'seq_len1') | |||
dataset.add_seq_len('words2', 'seq_len2') | |||
@@ -342,8 +337,8 @@ class MNLIPipe(MatchingPipe): | |||
class LCQMCPipe(MatchingPipe): | |||
def __init__(self, tokenizer='cn=char'): | |||
super().__init__(tokenizer=tokenizer) | |||
def __init__(self, tokenizer='cn=char', num_proc=0): | |||
super().__init__(tokenizer=tokenizer, num_proc=num_proc) | |||
def process_from_file(self, paths=None): | |||
data_bundle = LCQMCLoader().load(paths) | |||
@@ -354,8 +349,8 @@ class LCQMCPipe(MatchingPipe): | |||
class CNXNLIPipe(MatchingPipe): | |||
def __init__(self, tokenizer='cn-char'): | |||
super().__init__(tokenizer=tokenizer) | |||
def __init__(self, tokenizer='cn-char', num_proc=0): | |||
super().__init__(tokenizer=tokenizer, num_proc=num_proc) | |||
def process_from_file(self, paths=None): | |||
data_bundle = CNXNLILoader().load(paths) | |||
@@ -367,8 +362,8 @@ class CNXNLIPipe(MatchingPipe): | |||
class BQCorpusPipe(MatchingPipe): | |||
def __init__(self, tokenizer='cn-char'): | |||
super().__init__(tokenizer=tokenizer) | |||
def __init__(self, tokenizer='cn-char', num_proc=0): | |||
super().__init__(tokenizer=tokenizer, num_proc=num_proc) | |||
def process_from_file(self, paths=None): | |||
data_bundle = BQCorpusLoader().load(paths) | |||
@@ -379,9 +374,10 @@ class BQCorpusPipe(MatchingPipe): | |||
class RenamePipe(Pipe): | |||
def __init__(self, task='cn-nli'): | |||
def __init__(self, task='cn-nli', num_proc=0): | |||
super().__init__() | |||
self.task = task | |||
self.num_proc = num_proc | |||
def process(self, data_bundle: DataBundle): # rename field name for Chinese Matching dataset | |||
if (self.task == 'cn-nli'): | |||
@@ -419,9 +415,10 @@ class RenamePipe(Pipe): | |||
class GranularizePipe(Pipe): | |||
def __init__(self, task=None): | |||
def __init__(self, task=None, num_proc=0): | |||
super().__init__() | |||
self.task = task | |||
self.num_proc = num_proc | |||
def _granularize(self, data_bundle, tag_map): | |||
r""" | |||
@@ -434,8 +431,7 @@ class GranularizePipe(Pipe): | |||
""" | |||
for name in list(data_bundle.datasets.keys()): | |||
dataset = data_bundle.get_dataset(name) | |||
dataset.apply_field(lambda target: tag_map.get(target, -100), field_name='target', | |||
new_field_name='target') | |||
dataset.apply_field(lambda target: tag_map.get(target, -100), field_name='target', new_field_name='target') | |||
dataset.drop(lambda ins: ins['target'] == -100) | |||
data_bundle.set_dataset(dataset, name) | |||
return data_bundle | |||
@@ -462,8 +458,8 @@ class MachingTruncatePipe(Pipe): # truncate sentence for bert, modify seq_len | |||
class LCQMCBertPipe(MatchingBertPipe): | |||
def __init__(self, tokenizer='cn=char'): | |||
super().__init__(tokenizer=tokenizer) | |||
def __init__(self, tokenizer='cn=char', num_proc=0): | |||
super().__init__(tokenizer=tokenizer, num_proc=num_proc) | |||
def process_from_file(self, paths=None): | |||
data_bundle = LCQMCLoader().load(paths) | |||
@@ -475,8 +471,8 @@ class LCQMCBertPipe(MatchingBertPipe): | |||
class BQCorpusBertPipe(MatchingBertPipe): | |||
def __init__(self, tokenizer='cn-char'): | |||
super().__init__(tokenizer=tokenizer) | |||
def __init__(self, tokenizer='cn-char', num_proc=0): | |||
super().__init__(tokenizer=tokenizer, num_proc=num_proc) | |||
def process_from_file(self, paths=None): | |||
data_bundle = BQCorpusLoader().load(paths) | |||
@@ -488,8 +484,8 @@ class BQCorpusBertPipe(MatchingBertPipe): | |||
class CNXNLIBertPipe(MatchingBertPipe): | |||
def __init__(self, tokenizer='cn-char'): | |||
super().__init__(tokenizer=tokenizer) | |||
def __init__(self, tokenizer='cn-char', num_proc=0): | |||
super().__init__(tokenizer=tokenizer, num_proc=num_proc) | |||
def process_from_file(self, paths=None): | |||
data_bundle = CNXNLILoader().load(paths) | |||
@@ -502,9 +498,10 @@ class CNXNLIBertPipe(MatchingBertPipe): | |||
class TruncateBertPipe(Pipe): | |||
def __init__(self, task='cn'): | |||
def __init__(self, task='cn', num_proc=0): | |||
super().__init__() | |||
self.task = task | |||
self.num_proc = num_proc | |||
def _truncate(self, sentence_index:list, sep_index_vocab): | |||
# 根据[SEP]在vocab中的index,找到[SEP]在dataset的field['words']中的index | |||
@@ -528,7 +525,8 @@ class TruncateBertPipe(Pipe): | |||
for name in data_bundle.datasets.keys(): | |||
dataset = data_bundle.get_dataset(name) | |||
sep_index_vocab = data_bundle.get_vocab('words').to_index('[SEP]') | |||
dataset.apply_field(lambda sent_index: self._truncate(sentence_index=sent_index, sep_index_vocab=sep_index_vocab), field_name='words', new_field_name='words') | |||
dataset.apply_field(partial(self._truncate, sep_index_vocab=sep_index_vocab), field_name='words', | |||
new_field_name='words', num_proc=self.num_proc) | |||
# truncate之后需要更新seq_len | |||
dataset.add_seq_len(field_name='words') | |||
@@ -1,6 +1,7 @@ | |||
r"""undocumented""" | |||
import os | |||
import numpy as np | |||
from functools import partial | |||
from .pipe import Pipe | |||
from .utils import _drop_empty_instance | |||
@@ -25,7 +26,7 @@ class ExtCNNDMPipe(Pipe): | |||
:header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target" | |||
""" | |||
def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False): | |||
def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False, num_proc=0): | |||
r""" | |||
:param vocab_size: int, 词表大小 | |||
@@ -39,6 +40,7 @@ class ExtCNNDMPipe(Pipe): | |||
self.sent_max_len = sent_max_len | |||
self.doc_max_timesteps = doc_max_timesteps | |||
self.domain = domain | |||
self.num_proc = num_proc | |||
def process(self, data_bundle: DataBundle): | |||
r""" | |||
@@ -65,18 +67,29 @@ class ExtCNNDMPipe(Pipe): | |||
error_msg = 'vocab file is not defined!' | |||
print(error_msg) | |||
raise RuntimeError(error_msg) | |||
data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||
data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||
data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd') | |||
data_bundle.apply_field(_lower_text, field_name='text', new_field_name='text', num_proc=self.num_proc) | |||
data_bundle.apply_field(_lower_text, field_name='summary', new_field_name='summary', num_proc=self.num_proc) | |||
data_bundle.apply_field(_split_list, field_name='text', new_field_name='text_wd', num_proc=self.num_proc) | |||
# data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||
# data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||
# data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd') | |||
data_bundle.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name='target') | |||
data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name='words') | |||
data_bundle.apply_field(partial(_pad_sent, sent_max_len=self.sent_max_len), field_name="text_wd", | |||
new_field_name="words", num_proc=self.num_proc) | |||
# data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name='words') | |||
# db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask") | |||
# pad document | |||
data_bundle.apply(lambda x: _pad_doc(x['words'], self.sent_max_len, self.doc_max_timesteps), new_field_name='words') | |||
data_bundle.apply(lambda x: _sent_mask(x['words'], self.doc_max_timesteps), new_field_name='seq_len') | |||
data_bundle.apply(lambda x: _pad_label(x['target'], self.doc_max_timesteps), new_field_name='target') | |||
data_bundle.apply_field(partial(_pad_doc, sent_max_len=self.sent_max_len, doc_max_timesteps=self.doc_max_timesteps), | |||
field_name="words", new_field_name="words", num_proc=self.num_proc) | |||
data_bundle.apply_field(partial(_sent_mask, doc_max_timesteps=self.doc_max_timesteps), field_name="words", | |||
new_field_name="seq_len", num_proc=self.num_proc) | |||
data_bundle.apply_field(partial(_pad_label, doc_max_timesteps=self.doc_max_timesteps), field_name="target", | |||
new_field_name="target", num_proc=self.num_proc) | |||
# data_bundle.apply(lambda x: _pad_doc(x['words'], self.sent_max_len, self.doc_max_timesteps), new_field_name='words') | |||
# data_bundle.apply(lambda x: _sent_mask(x['words'], self.doc_max_timesteps), new_field_name='seq_len') | |||
# data_bundle.apply(lambda x: _pad_label(x['target'], self.doc_max_timesteps), new_field_name='target') | |||
data_bundle = _drop_empty_instance(data_bundle, "label") | |||
@@ -12,14 +12,24 @@ class TestClassificationPipe: | |||
def test_process_from_file(self): | |||
for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]: | |||
print(pipe) | |||
data_bundle = pipe(tokenizer='raw').process_from_file() | |||
data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file() | |||
print(data_bundle) | |||
def test_process_from_file_proc(self, num_proc=2): | |||
for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]: | |||
print(pipe) | |||
data_bundle = pipe(tokenizer='raw', num_proc=num_proc).process_from_file() | |||
print(data_bundle) | |||
class TestRunPipe: | |||
def test_load(self): | |||
for pipe in [IMDBPipe]: | |||
data_bundle = pipe(tokenizer='raw').process_from_file('tests/data_for_tests/io/imdb') | |||
data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file('tests/data_for_tests/io/imdb') | |||
print(data_bundle) | |||
def test_load_proc(self): | |||
for pipe in [IMDBPipe]: | |||
data_bundle = pipe(tokenizer='raw', num_proc=2).process_from_file('tests/data_for_tests/io/imdb') | |||
print(data_bundle) | |||
@@ -31,7 +41,7 @@ class TestCNClassificationPipe: | |||
print(data_bundle) | |||
@pytest.mark.skipif('download' not in os.environ, reason="Skip download") | |||
# @pytest.mark.skipif('download' not in os.environ, reason="Skip download") | |||
class TestRunClassificationPipe: | |||
def test_process_from_file(self): | |||
data_set_dict = { | |||
@@ -71,9 +81,9 @@ class TestRunClassificationPipe: | |||
path, pipe, data_set, vocab, warns = v | |||
if 'Chn' not in k: | |||
if warns: | |||
data_bundle = pipe(tokenizer='raw').process_from_file(path) | |||
data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file(path) | |||
else: | |||
data_bundle = pipe(tokenizer='raw').process_from_file(path) | |||
data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file(path) | |||
else: | |||
data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(path) | |||
@@ -87,3 +97,61 @@ class TestRunClassificationPipe: | |||
for name, vocabs in data_bundle.iter_vocabs(): | |||
assert(name in vocab.keys()) | |||
assert(vocab[name] == len(vocabs)) | |||
def test_process_from_file_proc(self): | |||
data_set_dict = { | |||
'yelp.p': ('tests/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe, | |||
{'train': 6, 'dev': 6, 'test': 6}, {'words': 1176, 'target': 2}, | |||
False), | |||
'yelp.f': ('tests/data_for_tests/io/yelp_review_full', YelpFullPipe, | |||
{'train': 6, 'dev': 6, 'test': 6}, {'words': 1166, 'target': 5}, | |||
False), | |||
'sst-2': ('tests/data_for_tests/io/SST-2', SST2Pipe, | |||
{'train': 5, 'dev': 5, 'test': 5}, {'words': 139, 'target': 2}, | |||
True), | |||
'sst': ('tests/data_for_tests/io/SST', SSTPipe, | |||
{'train': 354, 'dev': 6, 'test': 6}, {'words': 232, 'target': 5}, | |||
False), | |||
'imdb': ('tests/data_for_tests/io/imdb', IMDBPipe, | |||
{'train': 6, 'dev': 6, 'test': 6}, {'words': 1670, 'target': 2}, | |||
False), | |||
'ag': ('tests/data_for_tests/io/ag', AGsNewsPipe, | |||
{'train': 4, 'test': 5}, {'words': 257, 'target': 4}, | |||
False), | |||
'dbpedia': ('tests/data_for_tests/io/dbpedia', DBPediaPipe, | |||
{'train': 14, 'test': 5}, {'words': 496, 'target': 14}, | |||
False), | |||
'ChnSentiCorp': ('tests/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe, | |||
{'train': 6, 'dev': 6, 'test': 6}, | |||
{'chars': 529, 'bigrams': 1296, 'trigrams': 1483, 'target': 2}, | |||
False), | |||
'Chn-THUCNews': ('tests/data_for_tests/io/THUCNews', THUCNewsPipe, | |||
{'train': 9, 'dev': 9, 'test': 9}, {'chars': 1864, 'target': 9}, | |||
False), | |||
'Chn-WeiboSenti100k': ('tests/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe, | |||
{'train': 6, 'dev': 6, 'test': 7}, {'chars': 452, 'target': 2}, | |||
False), | |||
} | |||
for k, v in data_set_dict.items(): | |||
path, pipe, data_set, vocab, warns = v | |||
if 'Chn' not in k: | |||
if warns: | |||
data_bundle = pipe(tokenizer='raw', num_proc=2).process_from_file(path) | |||
else: | |||
data_bundle = pipe(tokenizer='raw', num_proc=2).process_from_file(path) | |||
else: | |||
# if k == 'ChnSentiCorp': | |||
# data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(path) | |||
# else: | |||
data_bundle = pipe(bigrams=True, trigrams=True, num_proc=2).process_from_file(path) | |||
assert(isinstance(data_bundle, DataBundle)) | |||
assert(len(data_set) == data_bundle.num_dataset) | |||
for name, dataset in data_bundle.iter_datasets(): | |||
assert(name in data_set.keys()) | |||
assert(data_set[name] == len(dataset)) | |||
assert(len(vocab) == data_bundle.num_vocab) | |||
for name, vocabs in data_bundle.iter_vocabs(): | |||
assert(name in vocab.keys()) | |||
assert(vocab[name] == len(vocabs)) |
@@ -22,6 +22,12 @@ class TestRunPipe: | |||
data_bundle = pipe().process_from_file('tests/data_for_tests/conll_2003_example.txt') | |||
print(data_bundle) | |||
def test_conll2003_proc(self): | |||
for pipe in [Conll2003Pipe, Conll2003NERPipe]: | |||
print(pipe) | |||
data_bundle = pipe(num_proc=2).process_from_file('tests/data_for_tests/conll_2003_example.txt') | |||
print(data_bundle) | |||
class TestNERPipe: | |||
def test_process_from_file(self): | |||
@@ -37,12 +43,33 @@ class TestNERPipe: | |||
data_bundle = pipe(encoding_type='bioes').process_from_file(f'tests/data_for_tests/io/{k}') | |||
print(data_bundle) | |||
def test_process_from_file_proc(self): | |||
data_dict = { | |||
'weibo_NER': WeiboNERPipe, | |||
'peopledaily': PeopleDailyPipe, | |||
'MSRA_NER': MsraNERPipe, | |||
} | |||
for k, v in data_dict.items(): | |||
pipe = v | |||
data_bundle = pipe(bigrams=True, trigrams=True, num_proc=2).process_from_file(f'tests/data_for_tests/io/{k}') | |||
print(data_bundle) | |||
data_bundle = pipe(encoding_type='bioes', num_proc=2).process_from_file(f'tests/data_for_tests/io/{k}') | |||
print(data_bundle) | |||
class TestConll2003Pipe: | |||
def test_conll(self): | |||
data_bundle = Conll2003Pipe().process_from_file('tests/data_for_tests/io/conll2003') | |||
print(data_bundle) | |||
def test_conll_proc(self): | |||
data_bundle = Conll2003Pipe(num_proc=2).process_from_file('tests/data_for_tests/io/conll2003') | |||
print(data_bundle) | |||
def test_OntoNotes(self): | |||
data_bundle = OntoNotesNERPipe().process_from_file('tests/data_for_tests/io/OntoNotes') | |||
print(data_bundle) | |||
def test_OntoNotes_proc(self): | |||
data_bundle = OntoNotesNERPipe(num_proc=2).process_from_file('tests/data_for_tests/io/OntoNotes') | |||
print(data_bundle) |
@@ -28,12 +28,25 @@ class TestRunCWSPipe: | |||
def test_process_from_file(self): | |||
dataset_names = ['msra', 'cityu', 'as', 'pku'] | |||
for dataset_name in dataset_names: | |||
data_bundle = CWSPipe(bigrams=True, trigrams=True).\ | |||
data_bundle = CWSPipe(bigrams=True, trigrams=True, num_proc=0).\ | |||
process_from_file(f'tests/data_for_tests/io/cws_{dataset_name}') | |||
print(data_bundle) | |||
def test_replace_number(self): | |||
data_bundle = CWSPipe(bigrams=True, replace_num_alpha=True).\ | |||
data_bundle = CWSPipe(bigrams=True, replace_num_alpha=True, num_proc=0).\ | |||
process_from_file(f'tests/data_for_tests/io/cws_pku') | |||
for word in ['<', '>', '<NUM>']: | |||
assert(data_bundle.get_vocab('chars').to_index(word) != 1) | |||
def test_process_from_file_proc(self): | |||
dataset_names = ['msra', 'cityu', 'as', 'pku'] | |||
for dataset_name in dataset_names: | |||
data_bundle = CWSPipe(bigrams=True, trigrams=True, num_proc=2).\ | |||
process_from_file(f'tests/data_for_tests/io/cws_{dataset_name}') | |||
print(data_bundle) | |||
def test_replace_number_proc(self): | |||
data_bundle = CWSPipe(bigrams=True, replace_num_alpha=True, num_proc=2).\ | |||
process_from_file(f'tests/data_for_tests/io/cws_pku') | |||
for word in ['<', '>', '<NUM>']: | |||
assert(data_bundle.get_vocab('chars').to_index(word) != 1) |
@@ -69,6 +69,47 @@ class TestRunMatchingPipe: | |||
name, vocabs = y | |||
assert(x + 1 if name == 'words' else x == len(vocabs)) | |||
def test_load_proc(self): | |||
data_set_dict = { | |||
'RTE': ('tests/data_for_tests/io/RTE', RTEPipe, RTEBertPipe, (5, 5, 5), (449, 2), True), | |||
'SNLI': ('tests/data_for_tests/io/SNLI', SNLIPipe, SNLIBertPipe, (5, 5, 5), (110, 3), False), | |||
'QNLI': ('tests/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True), | |||
'MNLI': ('tests/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True), | |||
'BQCorpus': ('tests/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False), | |||
'XNLI': ('tests/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 6, 8), (39, 3), False), | |||
'LCQMC': ('tests/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 5, 6), (36, 2), False), | |||
} | |||
for k, v in data_set_dict.items(): | |||
path, pipe1, pipe2, data_set, vocab, warns = v | |||
if warns: | |||
data_bundle1 = pipe1(tokenizer='raw', num_proc=2).process_from_file(path) | |||
data_bundle2 = pipe2(tokenizer='raw', num_proc=2).process_from_file(path) | |||
else: | |||
data_bundle1 = pipe1(tokenizer='raw', num_proc=2).process_from_file(path) | |||
data_bundle2 = pipe2(tokenizer='raw', num_proc=2).process_from_file(path) | |||
assert (isinstance(data_bundle1, DataBundle)) | |||
assert (len(data_set) == data_bundle1.num_dataset) | |||
print(k) | |||
print(data_bundle1) | |||
print(data_bundle2) | |||
for x, y in zip(data_set, data_bundle1.iter_datasets()): | |||
name, dataset = y | |||
assert (x == len(dataset)) | |||
assert (len(data_set) == data_bundle2.num_dataset) | |||
for x, y in zip(data_set, data_bundle2.iter_datasets()): | |||
name, dataset = y | |||
assert (x == len(dataset)) | |||
assert (len(vocab) == data_bundle1.num_vocab) | |||
for x, y in zip(vocab, data_bundle1.iter_vocabs()): | |||
name, vocabs = y | |||
assert (x == len(vocabs)) | |||
assert (len(vocab) == data_bundle2.num_vocab) | |||
for x, y in zip(vocab, data_bundle1.iter_vocabs()): | |||
name, vocabs = y | |||
assert (x + 1 if name == 'words' else x == len(vocabs)) | |||
@pytest.mark.skipif('download' not in os.environ, reason="Skip download") | |||
def test_spacy(self): | |||
data_set_dict = { | |||
@@ -69,3 +69,45 @@ class TestRunExtCNNDMPipe: | |||
db5 = dbPipe5.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl')) | |||
assert(isinstance(db5, DataBundle)) | |||
def test_load_proc(self): | |||
data_dir = 'tests/data_for_tests/io/cnndm' | |||
vocab_size = 100000 | |||
VOCAL_FILE = 'tests/data_for_tests/io/cnndm/vocab' | |||
sent_max_len = 100 | |||
doc_max_timesteps = 50 | |||
dbPipe = ExtCNNDMPipe(vocab_size=vocab_size, | |||
vocab_path=VOCAL_FILE, | |||
sent_max_len=sent_max_len, | |||
doc_max_timesteps=doc_max_timesteps, num_proc=2) | |||
dbPipe2 = ExtCNNDMPipe(vocab_size=vocab_size, | |||
vocab_path=VOCAL_FILE, | |||
sent_max_len=sent_max_len, | |||
doc_max_timesteps=doc_max_timesteps, | |||
domain=True, num_proc=2) | |||
db = dbPipe.process_from_file(data_dir) | |||
db2 = dbPipe2.process_from_file(data_dir) | |||
assert(isinstance(db, DataBundle)) | |||
assert(isinstance(db2, DataBundle)) | |||
dbPipe3 = ExtCNNDMPipe(vocab_size=vocab_size, | |||
sent_max_len=sent_max_len, | |||
doc_max_timesteps=doc_max_timesteps, | |||
domain=True, num_proc=2) | |||
db3 = dbPipe3.process_from_file(data_dir) | |||
assert(isinstance(db3, DataBundle)) | |||
with pytest.raises(RuntimeError): | |||
dbPipe4 = ExtCNNDMPipe(vocab_size=vocab_size, | |||
sent_max_len=sent_max_len, | |||
doc_max_timesteps=doc_max_timesteps, num_proc=2) | |||
db4 = dbPipe4.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl')) | |||
dbPipe5 = ExtCNNDMPipe(vocab_size=vocab_size, | |||
vocab_path=VOCAL_FILE, | |||
sent_max_len=sent_max_len, | |||
doc_max_timesteps=doc_max_timesteps, num_proc=2) | |||
db5 = dbPipe5.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl')) | |||
assert(isinstance(db5, DataBundle)) | |||