@@ -96,7 +96,7 @@ class TorchDataLoader(DataLoader):
     """
     def __init__(self, dataset, batch_size: int = 16,
-                 shuffle: bool = False, sampler = None, batch_sampler = None,
+                 shuffle: bool = False, sampler=None, batch_sampler=None,
                  num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto',
                  pin_memory: bool = False, drop_last: bool = False,
                  timeout: float = 0, worker_init_fn: Optional[Callable] = None,
@@ -37,10 +37,11 @@ from fastNLP.core.log import logger
 class CLSBasePipe(Pipe):
-    def __init__(self, lower: bool = False, tokenizer: str = 'raw', lang='en'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'raw', lang='en', num_proc=0):
         super().__init__()
         self.lower = lower
         self.tokenizer = get_tokenizer(tokenizer, lang=lang)
+        self.num_proc = num_proc
     def _tokenize(self, data_bundle, field_name='words', new_field_name=None):
         r"""
@@ -53,7 +54,8 @@ class CLSBasePipe(Pipe):
         """
         new_field_name = new_field_name or field_name
         for name, dataset in data_bundle.iter_datasets():
-            dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name)
+            dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name,
+                                num_proc=self.num_proc)
         return data_bundle
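Note on the change pattern: every pipe in this diff gains a `num_proc` argument, stores it, and forwards it to `dataset.apply_field` / `dataset.apply`. A usage sketch, assuming `num_proc` is the number of worker processes those calls use (the value 4 is illustrative; the import path and data path are the ones used by the tests in this PR):

    from fastNLP.io import IMDBPipe

    # num_proc=0 keeps the previous single-process behaviour; a larger value
    # presumably fans the per-instance work out over that many worker processes.
    data_bundle = IMDBPipe(tokenizer='raw', num_proc=4).process_from_file('tests/data_for_tests/io/imdb')
    print(data_bundle)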
@@ -117,7 +119,7 @@ class YelpFullPipe(CLSBasePipe):
     """
-    def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy', num_proc=0):
         r"""
         :param bool lower: 是否对输入进行小写化。
@@ -125,7 +127,7 @@ class YelpFullPipe(CLSBasePipe):
             1、2归为1类,3归为1类,4、5归为1类;若为5, 则有5分类问题。
         :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
         """
-        super().__init__(lower=lower, tokenizer=tokenizer, lang='en')
+        super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc)
         assert granularity in (2, 3, 5), "granularity can only be 2,3,5."
         self.granularity = granularity
@@ -191,13 +193,13 @@ class YelpPolarityPipe(CLSBasePipe):
     """
-    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0):
         r"""
         :param bool lower: 是否对输入进行小写化。
         :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
         """
-        super().__init__(lower=lower, tokenizer=tokenizer, lang='en')
+        super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc)
     def process_from_file(self, paths=None):
         r"""
@@ -233,13 +235,13 @@ class AGsNewsPipe(CLSBasePipe):
     """
-    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0):
         r"""
         :param bool lower: 是否对输入进行小写化。
         :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
         """
-        super().__init__(lower=lower, tokenizer=tokenizer, lang='en')
+        super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc)
     def process_from_file(self, paths=None):
         r"""
@@ -274,13 +276,13 @@ class DBPediaPipe(CLSBasePipe):
     """
-    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0):
         r"""
         :param bool lower: 是否对输入进行小写化。
         :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
         """
-        super().__init__(lower=lower, tokenizer=tokenizer, lang='en')
+        super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc)
     def process_from_file(self, paths=None):
         r"""
@@ -315,7 +317,7 @@ class SSTPipe(CLSBasePipe):
     """
-    def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'):
+    def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy', num_proc=0):
         r"""
         :param bool subtree: 是否将train, test, dev数据展开为子树,扩充数据量。 Default: ``False``
@@ -325,7 +327,7 @@ class SSTPipe(CLSBasePipe):
             0、1归为1类,2归为1类,3、4归为1类;若为5, 则有5分类问题。
         :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
         """
-        super().__init__(tokenizer=tokenizer, lang='en')
+        super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc)
         self.subtree = subtree
         self.train_tree = train_subtree
         self.lower = lower
@@ -407,13 +409,13 @@ class SST2Pipe(CLSBasePipe):
     """
-    def __init__(self, lower=False, tokenizer='raw'):
+    def __init__(self, lower=False, tokenizer='raw', num_proc=0):
         r"""
         :param bool lower: 是否对输入进行小写化。
         :param str tokenizer: 使用哪种tokenize方式将数据切成单词。
         """
-        super().__init__(lower=lower, tokenizer=tokenizer, lang='en')
+        super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc)
     def process_from_file(self, paths=None):
         r"""
@@ -452,13 +454,13 @@ class IMDBPipe(CLSBasePipe):
     """
-    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0):
         r"""
         :param bool lower: 是否将words列的数据小写。
         :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。
         """
-        super().__init__(tokenizer=tokenizer, lang='en')
+        super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc)
         self.lower = lower
     def process(self, data_bundle: DataBundle):
@@ -483,7 +485,7 @@ class IMDBPipe(CLSBasePipe):
             return raw_words
         for name, dataset in data_bundle.datasets.items():
-            dataset.apply_field(replace_br, field_name='raw_words', new_field_name='raw_words')
+            dataset.apply_field(replace_br, field_name='raw_words', new_field_name='raw_words', num_proc=self.num_proc)
         data_bundle = super().process(data_bundle)
@@ -527,7 +529,7 @@ class ChnSentiCorpPipe(Pipe):
     """
-    def __init__(self, bigrams=False, trigrams=False):
+    def __init__(self, bigrams=False, trigrams=False, num_proc: int = 0):
         r"""
         :param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果
@@ -541,10 +543,11 @@ class ChnSentiCorpPipe(Pipe):
         self.bigrams = bigrams
         self.trigrams = trigrams
+        self.num_proc = num_proc
     def _tokenize(self, data_bundle):
         r"""
-        将DataSet中的"复旦大学"拆分为["复", "旦", "大", "学"]. 未来可以通过扩展这个函数实现分词。
+        将 DataSet 中的"复旦大学"拆分为 ["复", "旦", "大", "学"] . 未来可以通过扩展这个函数实现分词。
         :param data_bundle:
         :return:
@@ -571,24 +574,26 @@ class ChnSentiCorpPipe(Pipe):
         data_bundle = self._tokenize(data_bundle)
         input_field_names = ['chars']
+        def bigrams(chars):
+            return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])]
+        def trigrams(chars):
+            return [c1 + c2 + c3 for c1, c2, c3 in
+                    zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)]
         if self.bigrams:
             for name, dataset in data_bundle.iter_datasets():
-                dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
-                                    field_name='chars', new_field_name='bigrams')
+                dataset.apply_field(bigrams,field_name='chars', new_field_name='bigrams', num_proc=self.num_proc)
                 input_field_names.append('bigrams')
         if self.trigrams:
             for name, dataset in data_bundle.iter_datasets():
-                dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
-                                                   zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
-                                    field_name='chars', new_field_name='trigrams')
+                dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc)
                 input_field_names.append('trigrams')
         # index
         _indexize(data_bundle, input_field_names, 'target')
-        input_fields = ['target', 'seq_len'] + input_field_names
-        target_fields = ['target']
         for name, dataset in data_bundle.datasets.items():
             dataset.add_seq_len('chars')
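Note: the named `bigrams`/`trigrams` helpers introduced above compute the same n-grams as the lambdas they replace; a quick worked example of their output on a four-character input:

    chars = ['复', '旦', '大', '学']
    bigrams(chars)   # ['复旦', '旦大', '大学', '学<eos>']
    trigrams(chars)  # ['复旦大', '旦大学', '大学<eos>', '学<eos><eos>']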
@@ -637,8 +642,8 @@ class THUCNewsPipe(CLSBasePipe):
         data_bundle.get_vocab('trigrams')获取.
     """
-    def __init__(self, bigrams=False, trigrams=False):
-        super().__init__()
+    def __init__(self, bigrams=False, trigrams=False, num_proc=0):
+        super().__init__(num_proc=num_proc)
         self.bigrams = bigrams
         self.trigrams = trigrams
@@ -653,7 +658,7 @@ class THUCNewsPipe(CLSBasePipe):
     def _tokenize(self, data_bundle, field_name='words', new_field_name=None):
         new_field_name = new_field_name or field_name
         for name, dataset in data_bundle.datasets.items():
-            dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name)
+            dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name, num_proc=self.num_proc)
         return data_bundle
     def process(self, data_bundle: DataBundle):
@@ -680,17 +685,21 @@ class THUCNewsPipe(CLSBasePipe):
         input_field_names = ['chars']
+        def bigrams(chars):
+            return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])]
+        def trigrams(chars):
+            return [c1 + c2 + c3 for c1, c2, c3 in
+                    zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)]
         # n-grams
         if self.bigrams:
             for name, dataset in data_bundle.iter_datasets():
-                dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
-                                    field_name='chars', new_field_name='bigrams')
+                dataset.apply_field(bigrams, field_name='chars', new_field_name='bigrams', num_proc=self.num_proc)
                 input_field_names.append('bigrams')
         if self.trigrams:
             for name, dataset in data_bundle.iter_datasets():
-                dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
-                                                   zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
-                                    field_name='chars', new_field_name='trigrams')
+                dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc)
                 input_field_names.append('trigrams')
         # index
@@ -700,9 +709,6 @@ class THUCNewsPipe(CLSBasePipe):
         for name, dataset in data_bundle.datasets.items():
             dataset.add_seq_len(field_name='chars', new_field_name='seq_len')
-        input_fields = ['target', 'seq_len'] + input_field_names
-        target_fields = ['target']
         return data_bundle
     def process_from_file(self, paths=None):
@@ -746,8 +752,8 @@ class WeiboSenti100kPipe(CLSBasePipe):
         data_bundle.get_vocab('trigrams')获取.
     """
-    def __init__(self, bigrams=False, trigrams=False):
-        super().__init__()
+    def __init__(self, bigrams=False, trigrams=False, num_proc=0):
+        super().__init__(num_proc=num_proc)
         self.bigrams = bigrams
         self.trigrams = trigrams
@@ -758,7 +764,8 @@ class WeiboSenti100kPipe(CLSBasePipe):
     def _tokenize(self, data_bundle, field_name='words', new_field_name=None):
         new_field_name = new_field_name or field_name
         for name, dataset in data_bundle.datasets.items():
-            dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name)
+            dataset.apply_field(self._chracter_split, field_name=field_name,
+                                new_field_name=new_field_name, num_proc=self.num_proc)
         return data_bundle
     def process(self, data_bundle: DataBundle):
@@ -779,20 +786,19 @@ class WeiboSenti100kPipe(CLSBasePipe):
         # CWS(tokenize)
         data_bundle = self._tokenize(data_bundle=data_bundle, field_name='raw_chars', new_field_name='chars')
-        input_field_names = ['chars']
+        def bigrams(chars):
+            return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])]
+        def trigrams(chars):
+            return [c1 + c2 + c3 for c1, c2, c3 in
+                    zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)]
         # n-grams
         if self.bigrams:
             for name, dataset in data_bundle.iter_datasets():
-                dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
-                                    field_name='chars', new_field_name='bigrams')
-                input_field_names.append('bigrams')
+                dataset.apply_field(bigrams, field_name='chars', new_field_name='bigrams', num_proc=self.num_proc)
         if self.trigrams:
             for name, dataset in data_bundle.iter_datasets():
-                dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
-                                                   zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
-                                    field_name='chars', new_field_name='trigrams')
-                input_field_names.append('trigrams')
+                dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc)
         # index
         data_bundle = _indexize(data_bundle=data_bundle, input_field_names='chars')
@@ -801,9 +807,6 @@ class WeiboSenti100kPipe(CLSBasePipe):
         for name, dataset in data_bundle.datasets.items():
             dataset.add_seq_len(field_name='chars', new_field_name='seq_len')
-        input_fields = ['target', 'seq_len'] + input_field_names
-        target_fields = ['target']
         return data_bundle
     def process_from_file(self, paths=None):
@@ -817,13 +820,13 @@ class WeiboSenti100kPipe(CLSBasePipe):
         return data_bundle
 class MRPipe(CLSBasePipe):
-    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0):
         r"""
         :param bool lower: 是否将words列的数据小写。
         :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。
         """
-        super().__init__(tokenizer=tokenizer, lang='en')
+        super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc)
         self.lower = lower
     def process_from_file(self, paths=None):
@@ -840,13 +843,13 @@ class MRPipe(CLSBasePipe):
 class R8Pipe(CLSBasePipe):
-    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc = 0):
         r"""
         :param bool lower: 是否将words列的数据小写。
         :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。
         """
-        super().__init__(tokenizer=tokenizer, lang='en')
+        super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc)
         self.lower = lower
     def process_from_file(self, paths=None):
@@ -863,13 +866,13 @@ class R8Pipe(CLSBasePipe):
 class R52Pipe(CLSBasePipe):
-    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int = 0):
         r"""
         :param bool lower: 是否将words列的数据小写。
         :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。
         """
-        super().__init__(tokenizer=tokenizer, lang='en')
+        super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc)
         self.lower = lower
     def process_from_file(self, paths=None):
@@ -886,13 +889,13 @@ class R52Pipe(CLSBasePipe):
 class OhsumedPipe(CLSBasePipe):
-    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int = 0):
         r"""
         :param bool lower: 是否将words列的数据小写。
         :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。
         """
-        super().__init__(tokenizer=tokenizer, lang='en')
+        super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc)
         self.lower = lower
     def process_from_file(self, paths=None):
@@ -909,13 +912,13 @@ class OhsumedPipe(CLSBasePipe):
 class NG20Pipe(CLSBasePipe):
-    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int = 0):
         r"""
         :param bool lower: 是否将words列的数据小写。
         :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。
         """
-        super().__init__(tokenizer=tokenizer, lang='en')
+        super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc)
         self.lower = lower
     def process_from_file(self, paths=None):
@@ -30,7 +30,7 @@ class _NERPipe(Pipe):
         target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target, seq_len。
     """
-    def __init__(self, encoding_type: str = 'bio', lower: bool = False):
+    def __init__(self, encoding_type: str = 'bio', lower: bool = False, num_proc=0):
         r"""
         :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
@@ -39,10 +39,14 @@ class _NERPipe(Pipe):
         if encoding_type == 'bio':
             self.convert_tag = iob2
         elif encoding_type == 'bioes':
-            self.convert_tag = lambda words: iob2bioes(iob2(words))
+            def func(words):
+                return iob2bioes(iob2(words))
+            # self.convert_tag = lambda words: iob2bioes(iob2(words))
+            self.convert_tag = func
         else:
             raise ValueError("encoding_type only supports `bio` and `bioes`.")
         self.lower = lower
+        self.num_proc = num_proc
     def process(self, data_bundle: DataBundle) -> DataBundle:
         r"""
@@ -60,16 +64,13 @@ class _NERPipe(Pipe):
         """
         # 转换tag
         for name, dataset in data_bundle.iter_datasets():
-            dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target')
+            dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target', num_proc=self.num_proc)
         _add_words_field(data_bundle, lower=self.lower)
         # index
         _indexize(data_bundle)
-        input_fields = ['target', 'words', 'seq_len']
-        target_fields = ['target', 'seq_len']
         for name, dataset in data_bundle.iter_datasets():
             dataset.add_seq_len('words')
@@ -144,7 +145,7 @@ class Conll2003Pipe(Pipe):
     """
-    def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False):
+    def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False, num_proc: int = 0):
         r"""
         :param str chunk_encoding_type: 支持bioes, bio。
@@ -154,16 +155,23 @@ class Conll2003Pipe(Pipe):
         if chunk_encoding_type == 'bio':
             self.chunk_convert_tag = iob2
         elif chunk_encoding_type == 'bioes':
-            self.chunk_convert_tag = lambda tags: iob2bioes(iob2(tags))
+            def func1(tags):
+                return iob2bioes(iob2(tags))
+            # self.chunk_convert_tag = lambda tags: iob2bioes(iob2(tags))
+            self.chunk_convert_tag = func1
         else:
             raise ValueError("chunk_encoding_type only supports `bio` and `bioes`.")
         if ner_encoding_type == 'bio':
             self.ner_convert_tag = iob2
         elif ner_encoding_type == 'bioes':
-            self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags))
+            def func2(tags):
+                return iob2bioes(iob2(tags))
+            # self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags))
+            self.ner_convert_tag = func2
         else:
             raise ValueError("ner_encoding_type only supports `bio` and `bioes`.")
         self.lower = lower
+        self.num_proc = num_proc
     def process(self, data_bundle) -> DataBundle:
         r"""
@@ -182,8 +190,8 @@ class Conll2003Pipe(Pipe):
         # 转换tag
         for name, dataset in data_bundle.datasets.items():
             dataset.drop(lambda x: "-DOCSTART-" in x['raw_words'])
-            dataset.apply_field(self.chunk_convert_tag, field_name='chunk', new_field_name='chunk')
-            dataset.apply_field(self.ner_convert_tag, field_name='ner', new_field_name='ner')
+            dataset.apply_field(self.chunk_convert_tag, field_name='chunk', new_field_name='chunk', num_proc=self.num_proc)
+            dataset.apply_field(self.ner_convert_tag, field_name='ner', new_field_name='ner', num_proc=self.num_proc)
         _add_words_field(data_bundle, lower=self.lower)
@@ -194,10 +202,7 @@ class Conll2003Pipe(Pipe):
         tgt_vocab.from_dataset(*data_bundle.datasets.values(), field_name='chunk')
         tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name='chunk')
         data_bundle.set_vocab(tgt_vocab, 'chunk')
-        input_fields = ['words', 'seq_len']
-        target_fields = ['pos', 'ner', 'chunk', 'seq_len']
         for name, dataset in data_bundle.iter_datasets():
             dataset.add_seq_len('words')
@@ -256,7 +261,7 @@ class _CNNERPipe(Pipe):
     """
-    def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False):
+    def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False, num_proc: int = 0):
         r"""
         :param str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
@@ -270,12 +275,16 @@ class _CNNERPipe(Pipe):
         if encoding_type == 'bio':
             self.convert_tag = iob2
         elif encoding_type == 'bioes':
-            self.convert_tag = lambda words: iob2bioes(iob2(words))
+            def func(words):
+                return iob2bioes(iob2(words))
+            # self.convert_tag = lambda words: iob2bioes(iob2(words))
+            self.convert_tag = func
         else:
             raise ValueError("encoding_type only supports `bio` and `bioes`.")
         self.bigrams = bigrams
         self.trigrams = trigrams
+        self.num_proc = num_proc
     def process(self, data_bundle: DataBundle) -> DataBundle:
         r"""
@@ -296,29 +305,31 @@ class _CNNERPipe(Pipe):
         """
         # 转换tag
         for name, dataset in data_bundle.datasets.items():
-            dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target')
+            dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target', num_proc=self.num_proc)
         _add_chars_field(data_bundle, lower=False)
         input_field_names = ['chars']
+        def bigrams(chars):
+            return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])]
+        def trigrams(chars):
+            return [c1 + c2 + c3 for c1, c2, c3 in
+                    zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)]
         if self.bigrams:
             for name, dataset in data_bundle.iter_datasets():
-                dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
-                                    field_name='chars', new_field_name='bigrams')
+                dataset.apply_field(bigrams, field_name='chars', new_field_name='bigrams', num_proc=self.num_proc)
                 input_field_names.append('bigrams')
         if self.trigrams:
             for name, dataset in data_bundle.datasets.items():
-                dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
-                                                   zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
-                                    field_name='chars', new_field_name='trigrams')
+                dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc)
                 input_field_names.append('trigrams')
         # index
         _indexize(data_bundle, input_field_names, 'target')
-        input_fields = ['target', 'seq_len'] + input_field_names
-        target_fields = ['target', 'seq_len']
         for name, dataset in data_bundle.iter_datasets():
             dataset.add_seq_len('chars')
@@ -157,7 +157,8 @@ class CWSPipe(Pipe):
     """
-    def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, bigrams=False, trigrams=False):
+    def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True,
+                 bigrams=False, trigrams=False, num_proc: int = 0):
         r"""
         :param str,None dataset_name: 支持'pku', 'msra', 'cityu', 'as', None
@@ -176,6 +177,7 @@ class CWSPipe(Pipe):
         self.bigrams = bigrams
         self.trigrams = trigrams
         self.replace_num_alpha = replace_num_alpha
+        self.num_proc = num_proc
     def _tokenize(self, data_bundle):
         r"""
@@ -213,7 +215,7 @@ class CWSPipe(Pipe):
         for name, dataset in data_bundle.iter_datasets():
             dataset.apply_field(split_word_into_chars, field_name='chars',
-                                new_field_name='chars')
+                                new_field_name='chars', num_proc=self.num_proc)
         return data_bundle
     def process(self, data_bundle: DataBundle) -> DataBundle:
@@ -233,33 +235,40 @@ class CWSPipe(Pipe):
         data_bundle.copy_field('raw_words', 'chars')
         if self.replace_num_alpha:
-            data_bundle.apply_field(_find_and_replace_alpha_spans, 'chars', 'chars')
-            data_bundle.apply_field(_find_and_replace_digit_spans, 'chars', 'chars')
+            data_bundle.apply_field(_find_and_replace_alpha_spans, 'chars', 'chars', num_proc=self.num_proc)
+            data_bundle.apply_field(_find_and_replace_digit_spans, 'chars', 'chars', num_proc=self.num_proc)
         self._tokenize(data_bundle)
+        def func1(chars):
+            return self.word_lens_to_tags(map(len, chars))
+        def func2(chars):
+            return list(chain(*chars))
         for name, dataset in data_bundle.iter_datasets():
-            dataset.apply_field(lambda chars: self.word_lens_to_tags(map(len, chars)), field_name='chars',
-                                new_field_name='target')
-            dataset.apply_field(lambda chars: list(chain(*chars)), field_name='chars',
-                                new_field_name='chars')
+            dataset.apply_field(func1, field_name='chars', new_field_name='target', num_proc=self.num_proc)
+            dataset.apply_field(func2, field_name='chars', new_field_name='chars', num_proc=self.num_proc)
         input_field_names = ['chars']
+        def bigram(chars):
+            return [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])]
+        def trigrams(chars):
+            return [c1 + c2 + c3 for c1, c2, c3 in
+                    zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)]
         if self.bigrams:
             for name, dataset in data_bundle.iter_datasets():
-                dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
-                                    field_name='chars', new_field_name='bigrams')
+                dataset.apply_field(bigram, field_name='chars', new_field_name='bigrams', num_proc=self.num_proc)
                 input_field_names.append('bigrams')
         if self.trigrams:
             for name, dataset in data_bundle.iter_datasets():
-                dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
-                                                   zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
-                                    field_name='chars', new_field_name='trigrams')
+                dataset.apply_field(trigrams, field_name='chars', new_field_name='trigrams', num_proc=self.num_proc)
                 input_field_names.append('trigrams')
         _indexize(data_bundle, input_field_names, 'target')
-        input_fields = ['target', 'seq_len'] + input_field_names
-        target_fields = ['target', 'seq_len']
         for name, dataset in data_bundle.iter_datasets():
             dataset.add_seq_len('chars')
@@ -23,6 +23,7 @@ __all__ = [
     "GranularizePipe",
     "MachingTruncatePipe",
 ]
+from functools import partial
 from fastNLP.core.log import logger
 from .pipe import Pipe
@@ -63,7 +64,7 @@ class MatchingBertPipe(Pipe):
     """
-    def __init__(self, lower=False, tokenizer: str = 'raw'):
+    def __init__(self, lower=False, tokenizer: str = 'raw', num_proc: int = 0):
         r"""
         :param bool lower: 是否将word小写化。
@@ -73,6 +74,7 @@ class MatchingBertPipe(Pipe):
         self.lower = bool(lower)
         self.tokenizer = get_tokenizer(tokenize_method=tokenizer)
+        self.num_proc = num_proc
     def _tokenize(self, data_bundle, field_names, new_field_names):
         r"""
@@ -84,8 +86,7 @@ class MatchingBertPipe(Pipe):
         """
         for name, dataset in data_bundle.iter_datasets():
             for field_name, new_field_name in zip(field_names, new_field_names):
-                dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name,
-                                    new_field_name=new_field_name)
+                dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name, num_proc=self.num_proc)
         return data_bundle
     def process(self, data_bundle):
@@ -124,8 +125,8 @@ class MatchingBertPipe(Pipe):
             words = words0 + ['[SEP]'] + words1
             return words
-        for name, dataset in data_bundle.datasets.items():
-            dataset.apply(concat, new_field_name='words')
+        for name, dataset in data_bundle.iter_datasets():
+            dataset.apply(concat, new_field_name='words', num_proc=self.num_proc)
             dataset.delete_field('words1')
             dataset.delete_field('words2')
@@ -155,10 +156,7 @@ class MatchingBertPipe(Pipe):
         data_bundle.set_vocab(word_vocab, 'words')
         data_bundle.set_vocab(target_vocab, 'target')
-        input_fields = ['words', 'seq_len']
-        target_fields = ['target']
         for name, dataset in data_bundle.iter_datasets():
             dataset.add_seq_len('words')
@@ -223,7 +221,7 @@ class MatchingPipe(Pipe):
     """
-    def __init__(self, lower=False, tokenizer: str = 'raw'):
+    def __init__(self, lower=False, tokenizer: str = 'raw', num_proc: int = 0):
         r"""
         :param bool lower: 是否将所有raw_words转为小写。
@@ -233,6 +231,7 @@ class MatchingPipe(Pipe):
         self.lower = bool(lower)
         self.tokenizer = get_tokenizer(tokenize_method=tokenizer)
+        self.num_proc = num_proc
     def _tokenize(self, data_bundle, field_names, new_field_names):
         r"""
@@ -244,8 +243,7 @@ class MatchingPipe(Pipe):
         """
         for name, dataset in data_bundle.iter_datasets():
             for field_name, new_field_name in zip(field_names, new_field_names):
-                dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name,
-                                    new_field_name=new_field_name)
+                dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name, num_proc=self.num_proc)
         return data_bundle
     def process(self, data_bundle):
@@ -300,10 +298,7 @@ class MatchingPipe(Pipe):
         data_bundle.set_vocab(word_vocab, 'words1')
         data_bundle.set_vocab(target_vocab, 'target')
-        input_fields = ['words1', 'words2', 'seq_len1', 'seq_len2']
-        target_fields = ['target']
         for name, dataset in data_bundle.datasets.items():
             dataset.add_seq_len('words1', 'seq_len1')
             dataset.add_seq_len('words2', 'seq_len2')
@@ -342,8 +337,8 @@ class MNLIPipe(MatchingPipe):
 class LCQMCPipe(MatchingPipe):
-    def __init__(self, tokenizer='cn=char'):
-        super().__init__(tokenizer=tokenizer)
+    def __init__(self, tokenizer='cn=char', num_proc=0):
+        super().__init__(tokenizer=tokenizer, num_proc=num_proc)
     def process_from_file(self, paths=None):
         data_bundle = LCQMCLoader().load(paths)
@@ -354,8 +349,8 @@ class LCQMCPipe(MatchingPipe):
 class CNXNLIPipe(MatchingPipe):
-    def __init__(self, tokenizer='cn-char'):
-        super().__init__(tokenizer=tokenizer)
+    def __init__(self, tokenizer='cn-char', num_proc=0):
+        super().__init__(tokenizer=tokenizer, num_proc=num_proc)
     def process_from_file(self, paths=None):
         data_bundle = CNXNLILoader().load(paths)
@@ -367,8 +362,8 @@ class CNXNLIPipe(MatchingPipe):
 class BQCorpusPipe(MatchingPipe):
-    def __init__(self, tokenizer='cn-char'):
-        super().__init__(tokenizer=tokenizer)
+    def __init__(self, tokenizer='cn-char', num_proc=0):
+        super().__init__(tokenizer=tokenizer, num_proc=num_proc)
     def process_from_file(self, paths=None):
         data_bundle = BQCorpusLoader().load(paths)
@@ -379,9 +374,10 @@ class BQCorpusPipe(MatchingPipe):
 class RenamePipe(Pipe):
-    def __init__(self, task='cn-nli'):
+    def __init__(self, task='cn-nli', num_proc=0):
         super().__init__()
         self.task = task
+        self.num_proc = num_proc
     def process(self, data_bundle: DataBundle): # rename field name for Chinese Matching dataset
         if (self.task == 'cn-nli'):
@@ -419,9 +415,10 @@ class RenamePipe(Pipe):
 class GranularizePipe(Pipe):
-    def __init__(self, task=None):
+    def __init__(self, task=None, num_proc=0):
         super().__init__()
         self.task = task
+        self.num_proc = num_proc
     def _granularize(self, data_bundle, tag_map):
         r"""
@@ -434,8 +431,7 @@ class GranularizePipe(Pipe):
         """
         for name in list(data_bundle.datasets.keys()):
             dataset = data_bundle.get_dataset(name)
-            dataset.apply_field(lambda target: tag_map.get(target, -100), field_name='target',
-                                new_field_name='target')
+            dataset.apply_field(lambda target: tag_map.get(target, -100), field_name='target', new_field_name='target')
             dataset.drop(lambda ins: ins['target'] == -100)
             data_bundle.set_dataset(dataset, name)
         return data_bundle
@@ -462,8 +458,8 @@ class MachingTruncatePipe(Pipe): # truncate sentence for bert, modify seq_len
 class LCQMCBertPipe(MatchingBertPipe):
-    def __init__(self, tokenizer='cn=char'):
-        super().__init__(tokenizer=tokenizer)
+    def __init__(self, tokenizer='cn=char', num_proc=0):
+        super().__init__(tokenizer=tokenizer, num_proc=num_proc)
     def process_from_file(self, paths=None):
         data_bundle = LCQMCLoader().load(paths)
@@ -475,8 +471,8 @@ class LCQMCBertPipe(MatchingBertPipe):
 class BQCorpusBertPipe(MatchingBertPipe):
-    def __init__(self, tokenizer='cn-char'):
-        super().__init__(tokenizer=tokenizer)
+    def __init__(self, tokenizer='cn-char', num_proc=0):
+        super().__init__(tokenizer=tokenizer, num_proc=num_proc)
     def process_from_file(self, paths=None):
         data_bundle = BQCorpusLoader().load(paths)
@@ -488,8 +484,8 @@ class BQCorpusBertPipe(MatchingBertPipe):
 class CNXNLIBertPipe(MatchingBertPipe):
-    def __init__(self, tokenizer='cn-char'):
-        super().__init__(tokenizer=tokenizer)
+    def __init__(self, tokenizer='cn-char', num_proc=0):
+        super().__init__(tokenizer=tokenizer, num_proc=num_proc)
     def process_from_file(self, paths=None):
         data_bundle = CNXNLILoader().load(paths)
@@ -502,9 +498,10 @@ class CNXNLIBertPipe(MatchingBertPipe):
 class TruncateBertPipe(Pipe):
-    def __init__(self, task='cn'):
+    def __init__(self, task='cn', num_proc=0):
         super().__init__()
         self.task = task
+        self.num_proc = num_proc
     def _truncate(self, sentence_index:list, sep_index_vocab):
         # 根据[SEP]在vocab中的index,找到[SEP]在dataset的field['words']中的index
@@ -528,7 +525,8 @@ class TruncateBertPipe(Pipe):
         for name in data_bundle.datasets.keys():
             dataset = data_bundle.get_dataset(name)
             sep_index_vocab = data_bundle.get_vocab('words').to_index('[SEP]')
-            dataset.apply_field(lambda sent_index: self._truncate(sentence_index=sent_index, sep_index_vocab=sep_index_vocab), field_name='words', new_field_name='words')
+            dataset.apply_field(partial(self._truncate, sep_index_vocab=sep_index_vocab), field_name='words',
+                                new_field_name='words', num_proc=self.num_proc)
             # truncate之后需要更新seq_len
             dataset.add_seq_len(field_name='words')
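Note: `functools.partial` above plays the same role as the removed lambda: it binds `sep_index_vocab` ahead of time so `apply_field` only has to supply the per-instance `sentence_index`. A generic sketch of the pattern (the `scale` function is hypothetical, not fastNLP code):

    from functools import partial

    def scale(values, factor):
        return [v * factor for v in values]

    double = partial(scale, factor=2)
    double([1, 2, 3])   # [2, 4, 6] -- only the first positional argument is left open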
| @@ -1,6 +1,7 @@ | |||||
| r"""undocumented""" | r"""undocumented""" | ||||
| import os | import os | ||||
| import numpy as np | import numpy as np | ||||
| from functools import partial | |||||
| from .pipe import Pipe | from .pipe import Pipe | ||||
| from .utils import _drop_empty_instance | from .utils import _drop_empty_instance | ||||
| @@ -25,7 +26,7 @@ class ExtCNNDMPipe(Pipe): | |||||
| :header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target" | :header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target" | ||||
| """ | """ | ||||
| def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False): | |||||
| def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False, num_proc=0): | |||||
| r""" | r""" | ||||
| :param vocab_size: int, 词表大小 | :param vocab_size: int, 词表大小 | ||||
| @@ -39,6 +40,7 @@ class ExtCNNDMPipe(Pipe): | |||||
| self.sent_max_len = sent_max_len | self.sent_max_len = sent_max_len | ||||
| self.doc_max_timesteps = doc_max_timesteps | self.doc_max_timesteps = doc_max_timesteps | ||||
| self.domain = domain | self.domain = domain | ||||
| self.num_proc = num_proc | |||||
| def process(self, data_bundle: DataBundle): | def process(self, data_bundle: DataBundle): | ||||
| r""" | r""" | ||||
| @@ -65,18 +67,29 @@ class ExtCNNDMPipe(Pipe): | |||||
| error_msg = 'vocab file is not defined!' | error_msg = 'vocab file is not defined!' | ||||
| print(error_msg) | print(error_msg) | ||||
| raise RuntimeError(error_msg) | raise RuntimeError(error_msg) | ||||
| data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||||
| data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||||
| data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd') | |||||
| data_bundle.apply_field(_lower_text, field_name='text', new_field_name='text', num_proc=self.num_proc) | |||||
| data_bundle.apply_field(_lower_text, field_name='summary', new_field_name='summary', num_proc=self.num_proc) | |||||
| data_bundle.apply_field(_split_list, field_name='text', new_field_name='text_wd', num_proc=self.num_proc) | |||||
| # data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||||
| # data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||||
| # data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd') | |||||
| data_bundle.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name='target') | data_bundle.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name='target') | ||||
| data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name='words') | |||||
| data_bundle.apply_field(partial(_pad_sent, sent_max_len=self.sent_max_len), field_name="text_wd", | |||||
| new_field_name="words", num_proc=self.num_proc) | |||||
| # data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name='words') | |||||
| # db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask") | # db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask") | ||||
| # pad document | # pad document | ||||
| data_bundle.apply(lambda x: _pad_doc(x['words'], self.sent_max_len, self.doc_max_timesteps), new_field_name='words') | |||||
| data_bundle.apply(lambda x: _sent_mask(x['words'], self.doc_max_timesteps), new_field_name='seq_len') | |||||
| data_bundle.apply(lambda x: _pad_label(x['target'], self.doc_max_timesteps), new_field_name='target') | |||||
| data_bundle.apply_field(partial(_pad_doc, sent_max_len=self.sent_max_len, doc_max_timesteps=self.doc_max_timesteps), | |||||
| field_name="words", new_field_name="words", num_proc=self.num_proc) | |||||
| data_bundle.apply_field(partial(_sent_mask, doc_max_timesteps=self.doc_max_timesteps), field_name="words", | |||||
| new_field_name="seq_len", num_proc=self.num_proc) | |||||
| data_bundle.apply_field(partial(_pad_label, doc_max_timesteps=self.doc_max_timesteps), field_name="target", | |||||
| new_field_name="target", num_proc=self.num_proc) | |||||
| # data_bundle.apply(lambda x: _pad_doc(x['words'], self.sent_max_len, self.doc_max_timesteps), new_field_name='words') | |||||
| # data_bundle.apply(lambda x: _sent_mask(x['words'], self.doc_max_timesteps), new_field_name='seq_len') | |||||
| # data_bundle.apply(lambda x: _pad_label(x['target'], self.doc_max_timesteps), new_field_name='target') | |||||
| data_bundle = _drop_empty_instance(data_bundle, "label") | data_bundle = _drop_empty_instance(data_bundle, "label") | ||||
| @@ -12,14 +12,24 @@ class TestClassificationPipe: | |||||
| def test_process_from_file(self): | def test_process_from_file(self): | ||||
| for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]: | for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]: | ||||
| print(pipe) | print(pipe) | ||||
| data_bundle = pipe(tokenizer='raw').process_from_file() | |||||
| data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file() | |||||
| print(data_bundle) | print(data_bundle) | ||||
| def test_process_from_file_proc(self, num_proc=2): | |||||
| for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]: | |||||
| print(pipe) | |||||
| data_bundle = pipe(tokenizer='raw', num_proc=num_proc).process_from_file() | |||||
| print(data_bundle) | |||||
| class TestRunPipe: | class TestRunPipe: | ||||
| def test_load(self): | def test_load(self): | ||||
| for pipe in [IMDBPipe]: | for pipe in [IMDBPipe]: | ||||
| data_bundle = pipe(tokenizer='raw').process_from_file('tests/data_for_tests/io/imdb') | |||||
| data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file('tests/data_for_tests/io/imdb') | |||||
| print(data_bundle) | |||||
| def test_load_proc(self): | |||||
| for pipe in [IMDBPipe]: | |||||
| data_bundle = pipe(tokenizer='raw', num_proc=2).process_from_file('tests/data_for_tests/io/imdb') | |||||
| print(data_bundle) | print(data_bundle) | ||||
| @@ -31,7 +41,7 @@ class TestCNClassificationPipe: | |||||
| print(data_bundle) | print(data_bundle) | ||||
| @pytest.mark.skipif('download' not in os.environ, reason="Skip download") | |||||
| # @pytest.mark.skipif('download' not in os.environ, reason="Skip download") | |||||
| class TestRunClassificationPipe: | class TestRunClassificationPipe: | ||||
| def test_process_from_file(self): | def test_process_from_file(self): | ||||
| data_set_dict = { | data_set_dict = { | ||||
| @@ -71,9 +81,9 @@ class TestRunClassificationPipe: | |||||
| path, pipe, data_set, vocab, warns = v | path, pipe, data_set, vocab, warns = v | ||||
| if 'Chn' not in k: | if 'Chn' not in k: | ||||
| if warns: | if warns: | ||||
| data_bundle = pipe(tokenizer='raw').process_from_file(path) | |||||
| data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file(path) | |||||
| else: | else: | ||||
| data_bundle = pipe(tokenizer='raw').process_from_file(path) | |||||
| data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file(path) | |||||
| else: | else: | ||||
| data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(path) | data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(path) | ||||
@@ -87,3 +97,61 @@ class TestRunClassificationPipe:
            for name, vocabs in data_bundle.iter_vocabs():
                assert(name in vocab.keys())
                assert(vocab[name] == len(vocabs))

    def test_process_from_file_proc(self):
        data_set_dict = {
            'yelp.p': ('tests/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe,
                       {'train': 6, 'dev': 6, 'test': 6}, {'words': 1176, 'target': 2},
                       False),
            'yelp.f': ('tests/data_for_tests/io/yelp_review_full', YelpFullPipe,
                       {'train': 6, 'dev': 6, 'test': 6}, {'words': 1166, 'target': 5},
                       False),
            'sst-2': ('tests/data_for_tests/io/SST-2', SST2Pipe,
                      {'train': 5, 'dev': 5, 'test': 5}, {'words': 139, 'target': 2},
                      True),
            'sst': ('tests/data_for_tests/io/SST', SSTPipe,
                    {'train': 354, 'dev': 6, 'test': 6}, {'words': 232, 'target': 5},
                    False),
            'imdb': ('tests/data_for_tests/io/imdb', IMDBPipe,
                     {'train': 6, 'dev': 6, 'test': 6}, {'words': 1670, 'target': 2},
                     False),
            'ag': ('tests/data_for_tests/io/ag', AGsNewsPipe,
                   {'train': 4, 'test': 5}, {'words': 257, 'target': 4},
                   False),
            'dbpedia': ('tests/data_for_tests/io/dbpedia', DBPediaPipe,
                        {'train': 14, 'test': 5}, {'words': 496, 'target': 14},
                        False),
            'ChnSentiCorp': ('tests/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe,
                             {'train': 6, 'dev': 6, 'test': 6},
                             {'chars': 529, 'bigrams': 1296, 'trigrams': 1483, 'target': 2},
                             False),
            'Chn-THUCNews': ('tests/data_for_tests/io/THUCNews', THUCNewsPipe,
                             {'train': 9, 'dev': 9, 'test': 9}, {'chars': 1864, 'target': 9},
                             False),
            'Chn-WeiboSenti100k': ('tests/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe,
                                   {'train': 6, 'dev': 6, 'test': 7}, {'chars': 452, 'target': 2},
                                   False),
        }
        for k, v in data_set_dict.items():
            path, pipe, data_set, vocab, warns = v
            if 'Chn' not in k:
                if warns:
                    data_bundle = pipe(tokenizer='raw', num_proc=2).process_from_file(path)
                else:
                    data_bundle = pipe(tokenizer='raw', num_proc=2).process_from_file(path)
            else:
                # if k == 'ChnSentiCorp':
                #     data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(path)
                # else:
                data_bundle = pipe(bigrams=True, trigrams=True, num_proc=2).process_from_file(path)
            assert(isinstance(data_bundle, DataBundle))
            assert(len(data_set) == data_bundle.num_dataset)
            for name, dataset in data_bundle.iter_datasets():
                assert(name in data_set.keys())
                assert(data_set[name] == len(dataset))
            assert(len(vocab) == data_bundle.num_vocab)
            for name, vocabs in data_bundle.iter_vocabs():
                assert(name in vocab.keys())
                assert(vocab[name] == len(vocabs))
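Both variants of the test exercise the same pipes and expect identical split and vocabulary sizes; the only thing that changes is whether field processing runs in-process (`num_proc=0`) or in a small worker pool. A minimal sketch of that kind of dispatch, purely illustrative rather than fastNLP's actual `apply_field`:

from multiprocessing import Pool


def tokenize(text):
    """Whitespace tokenization, standing in for a pipe's tokenizer."""
    return text.split()


def map_field(rows, fn, num_proc=0):
    """Apply fn to every row: sequentially when num_proc <= 1, otherwise with a pool."""
    if num_proc <= 1:
        return [fn(r) for r in rows]
    with Pool(processes=num_proc) as pool:
        return pool.map(fn, rows)


if __name__ == '__main__':
    rows = ['it was great', 'terrible plot']
    # Same result either way; only the execution strategy differs.
    assert map_field(rows, tokenize, num_proc=0) == map_field(rows, tokenize, num_proc=2)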
@@ -22,6 +22,12 @@ class TestRunPipe:
            data_bundle = pipe().process_from_file('tests/data_for_tests/conll_2003_example.txt')
            print(data_bundle)

    def test_conll2003_proc(self):
        for pipe in [Conll2003Pipe, Conll2003NERPipe]:
            print(pipe)
            data_bundle = pipe(num_proc=2).process_from_file('tests/data_for_tests/conll_2003_example.txt')
            print(data_bundle)


class TestNERPipe:
    def test_process_from_file(self):
@@ -37,12 +43,33 @@ class TestNERPipe:
            data_bundle = pipe(encoding_type='bioes').process_from_file(f'tests/data_for_tests/io/{k}')
            print(data_bundle)

    def test_process_from_file_proc(self):
        data_dict = {
            'weibo_NER': WeiboNERPipe,
            'peopledaily': PeopleDailyPipe,
            'MSRA_NER': MsraNERPipe,
        }
        for k, v in data_dict.items():
            pipe = v
            data_bundle = pipe(bigrams=True, trigrams=True, num_proc=2).process_from_file(f'tests/data_for_tests/io/{k}')
            print(data_bundle)
            data_bundle = pipe(encoding_type='bioes', num_proc=2).process_from_file(f'tests/data_for_tests/io/{k}')
            print(data_bundle)


class TestConll2003Pipe:
    def test_conll(self):
        data_bundle = Conll2003Pipe().process_from_file('tests/data_for_tests/io/conll2003')
        print(data_bundle)

    def test_conll_proc(self):
        data_bundle = Conll2003Pipe(num_proc=2).process_from_file('tests/data_for_tests/io/conll2003')
        print(data_bundle)

    def test_OntoNotes(self):
        data_bundle = OntoNotesNERPipe().process_from_file('tests/data_for_tests/io/OntoNotes')
        print(data_bundle)

    def test_OntoNotes_proc(self):
        data_bundle = OntoNotesNERPipe(num_proc=2).process_from_file('tests/data_for_tests/io/OntoNotes')
        print(data_bundle)
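Several of these pipes accept `encoding_type='bioes'`. For readers unfamiliar with the tag scheme, converting BIO tags to BIOES is roughly the following transformation (a sketch for illustration; fastNLP ships its own converter):

def iob2bioes(tags):
    """Rewrite a BIO tag sequence so single-token spans become S- and span ends become E-."""
    new_tags = []
    for i, tag in enumerate(tags):
        if tag == 'O':
            new_tags.append(tag)
            continue
        prefix, label = tag.split('-', 1)
        nxt = tags[i + 1] if i + 1 < len(tags) else 'O'
        span_ends = not (nxt.startswith('I-') and nxt[2:] == label)
        if prefix == 'B':
            new_tags.append(('S-' if span_ends else 'B-') + label)
        else:  # prefix == 'I'
            new_tags.append(('E-' if span_ends else 'I-') + label)
    return new_tags


print(iob2bioes(['B-PER', 'I-PER', 'O', 'B-LOC']))  # ['B-PER', 'E-PER', 'O', 'S-LOC']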
@@ -28,12 +28,25 @@ class TestRunCWSPipe:
    def test_process_from_file(self):
        dataset_names = ['msra', 'cityu', 'as', 'pku']
        for dataset_name in dataset_names:
            data_bundle = CWSPipe(bigrams=True, trigrams=True).\
            data_bundle = CWSPipe(bigrams=True, trigrams=True, num_proc=0).\
                process_from_file(f'tests/data_for_tests/io/cws_{dataset_name}')
            print(data_bundle)

    def test_replace_number(self):
        data_bundle = CWSPipe(bigrams=True, replace_num_alpha=True).\
        data_bundle = CWSPipe(bigrams=True, replace_num_alpha=True, num_proc=0).\
            process_from_file(f'tests/data_for_tests/io/cws_pku')
        for word in ['<', '>', '<NUM>']:
            assert(data_bundle.get_vocab('chars').to_index(word) != 1)

    def test_process_from_file_proc(self):
        dataset_names = ['msra', 'cityu', 'as', 'pku']
        for dataset_name in dataset_names:
            data_bundle = CWSPipe(bigrams=True, trigrams=True, num_proc=2).\
                process_from_file(f'tests/data_for_tests/io/cws_{dataset_name}')
            print(data_bundle)

    def test_replace_number_proc(self):
        data_bundle = CWSPipe(bigrams=True, replace_num_alpha=True, num_proc=2).\
            process_from_file(f'tests/data_for_tests/io/cws_pku')
        for word in ['<', '>', '<NUM>']:
            assert(data_bundle.get_vocab('chars').to_index(word) != 1)
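The `bigrams=True` / `trigrams=True` options add character n-gram fields next to the raw characters; conceptually, the bigram field pairs every character with its successor, roughly as below (a sketch only — the '<eos>' padding token is an assumption about the convention, not a quote of fastNLP's code):

def build_bigrams(chars, pad='<eos>'):
    """Join each character with the following one, padding the final position."""
    return [a + b for a, b in zip(chars, chars[1:] + [pad])]


print(build_bigrams(list('今天天气好')))  # ['今天', '天天', '天气', '气好', '好<eos>']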
@@ -69,6 +69,47 @@ class TestRunMatchingPipe:
                name, vocabs = y
                assert(x + 1 if name == 'words' else x == len(vocabs))

    def test_load_proc(self):
        data_set_dict = {
            'RTE': ('tests/data_for_tests/io/RTE', RTEPipe, RTEBertPipe, (5, 5, 5), (449, 2), True),
            'SNLI': ('tests/data_for_tests/io/SNLI', SNLIPipe, SNLIBertPipe, (5, 5, 5), (110, 3), False),
            'QNLI': ('tests/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True),
            'MNLI': ('tests/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True),
            'BQCorpus': ('tests/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False),
            'XNLI': ('tests/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 6, 8), (39, 3), False),
            'LCQMC': ('tests/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 5, 6), (36, 2), False),
        }
        for k, v in data_set_dict.items():
            path, pipe1, pipe2, data_set, vocab, warns = v
            if warns:
                data_bundle1 = pipe1(tokenizer='raw', num_proc=2).process_from_file(path)
                data_bundle2 = pipe2(tokenizer='raw', num_proc=2).process_from_file(path)
            else:
                data_bundle1 = pipe1(tokenizer='raw', num_proc=2).process_from_file(path)
                data_bundle2 = pipe2(tokenizer='raw', num_proc=2).process_from_file(path)
            assert (isinstance(data_bundle1, DataBundle))
            assert (len(data_set) == data_bundle1.num_dataset)
            print(k)
            print(data_bundle1)
            print(data_bundle2)
            for x, y in zip(data_set, data_bundle1.iter_datasets()):
                name, dataset = y
                assert (x == len(dataset))
            assert (len(data_set) == data_bundle2.num_dataset)
            for x, y in zip(data_set, data_bundle2.iter_datasets()):
                name, dataset = y
                assert (x == len(dataset))
            assert (len(vocab) == data_bundle1.num_vocab)
            for x, y in zip(vocab, data_bundle1.iter_vocabs()):
                name, vocabs = y
                assert (x == len(vocabs))
            assert (len(vocab) == data_bundle2.num_vocab)
            for x, y in zip(vocab, data_bundle2.iter_vocabs()):
                name, vocabs = y
                assert (x + 1 if name == 'words' else x == len(vocabs))

    @pytest.mark.skipif('download' not in os.environ, reason="Skip download")
    def test_spacy(self):
        data_set_dict = {
@@ -69,3 +69,45 @@ class TestRunExtCNNDMPipe:
        db5 = dbPipe5.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl'))
        assert(isinstance(db5, DataBundle))

    def test_load_proc(self):
        data_dir = 'tests/data_for_tests/io/cnndm'
        vocab_size = 100000
        VOCAL_FILE = 'tests/data_for_tests/io/cnndm/vocab'
        sent_max_len = 100
        doc_max_timesteps = 50
        dbPipe = ExtCNNDMPipe(vocab_size=vocab_size,
                              vocab_path=VOCAL_FILE,
                              sent_max_len=sent_max_len,
                              doc_max_timesteps=doc_max_timesteps, num_proc=2)
        dbPipe2 = ExtCNNDMPipe(vocab_size=vocab_size,
                               vocab_path=VOCAL_FILE,
                               sent_max_len=sent_max_len,
                               doc_max_timesteps=doc_max_timesteps,
                               domain=True, num_proc=2)
        db = dbPipe.process_from_file(data_dir)
        db2 = dbPipe2.process_from_file(data_dir)
        assert(isinstance(db, DataBundle))
        assert(isinstance(db2, DataBundle))

        dbPipe3 = ExtCNNDMPipe(vocab_size=vocab_size,
                               sent_max_len=sent_max_len,
                               doc_max_timesteps=doc_max_timesteps,
                               domain=True, num_proc=2)
        db3 = dbPipe3.process_from_file(data_dir)
        assert(isinstance(db3, DataBundle))

        with pytest.raises(RuntimeError):
            dbPipe4 = ExtCNNDMPipe(vocab_size=vocab_size,
                                   sent_max_len=sent_max_len,
                                   doc_max_timesteps=doc_max_timesteps, num_proc=2)
            db4 = dbPipe4.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl'))

        dbPipe5 = ExtCNNDMPipe(vocab_size=vocab_size,
                               vocab_path=VOCAL_FILE,
                               sent_max_len=sent_max_len,
                               doc_max_timesteps=doc_max_timesteps, num_proc=2)
        db5 = dbPipe5.process_from_file(os.path.join(data_dir, 'train.cnndm.jsonl'))
        assert(isinstance(db5, DataBundle))
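The `pytest.raises(RuntimeError)` case pins down the one invalid configuration in this test: dbPipe4 supplies neither `vocab_path` nor `domain=True`, so the pipe presumably has no source from which to obtain a vocabulary. A hedged sketch of that kind of guard (illustrative only; the function name and error message are not the library's):

def check_vocab_source(vocab_path=None, domain=False):
    """Require some vocabulary source: an explicit vocab file, or domain-based construction."""
    if vocab_path is None and not domain:
        raise RuntimeError('either vocab_path or domain=True is required to build a vocabulary')


check_vocab_source(vocab_path='tests/data_for_tests/io/cnndm/vocab')  # passes
check_vocab_source(domain=True)                                       # passes
# check_vocab_source()  # would raise RuntimeError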