From 515e4f4987106009d30e53c0865c89a389712d17 Mon Sep 17 00:00:00 2001
From: yh
Date: Fri, 9 Nov 2018 22:02:10 +0800
Subject: [PATCH] Move processors into processor.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/api/processor.py                     | 105 +++++++++++++++++-
 .../process/cws_processor.py                 |  94 ++--------------
 2 files changed, 111 insertions(+), 88 deletions(-)

diff --git a/fastNLP/api/processor.py b/fastNLP/api/processor.py
index 793cfe10..a01810ac 100644
--- a/fastNLP/api/processor.py
+++ b/fastNLP/api/processor.py
@@ -1,4 +1,6 @@
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.vocabulary import Vocabulary
 
 
 class Processor:
     def __init__(self, field_name, new_added_field_name):
@@ -12,4 +14,105 @@ class Processor:
         pass
 
     def __call__(self, *args, **kwargs):
-        return self.process(*args, **kwargs)
\ No newline at end of file
+        return self.process(*args, **kwargs)
+
+
+class FullSpaceToHalfSpaceProcessor(Processor):
+    def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True,
+                 change_space=True):
+        super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None)
+
+        self.change_alpha = change_alpha
+        self.change_digit = change_digit
+        self.change_punctuation = change_punctuation
+        self.change_space = change_space
+
+        FH_SPACE = [(u"　", u" ")]
+        FH_NUM = [
+            (u"０", u"0"), (u"１", u"1"), (u"２", u"2"), (u"３", u"3"), (u"４", u"4"),
+            (u"５", u"5"), (u"６", u"6"), (u"７", u"7"), (u"８", u"8"), (u"９", u"9")]
+        FH_ALPHA = [
+            (u"ａ", u"a"), (u"ｂ", u"b"), (u"ｃ", u"c"), (u"ｄ", u"d"), (u"ｅ", u"e"),
+            (u"ｆ", u"f"), (u"ｇ", u"g"), (u"ｈ", u"h"), (u"ｉ", u"i"), (u"ｊ", u"j"),
+            (u"ｋ", u"k"), (u"ｌ", u"l"), (u"ｍ", u"m"), (u"ｎ", u"n"), (u"ｏ", u"o"),
+            (u"ｐ", u"p"), (u"ｑ", u"q"), (u"ｒ", u"r"), (u"ｓ", u"s"), (u"ｔ", u"t"),
+            (u"ｕ", u"u"), (u"ｖ", u"v"), (u"ｗ", u"w"), (u"ｘ", u"x"), (u"ｙ", u"y"),
+            (u"ｚ", u"z"),
+            (u"Ａ", u"A"), (u"Ｂ", u"B"), (u"Ｃ", u"C"), (u"Ｄ", u"D"), (u"Ｅ", u"E"),
+            (u"Ｆ", u"F"), (u"Ｇ", u"G"), (u"Ｈ", u"H"), (u"Ｉ", u"I"), (u"Ｊ", u"J"),
+            (u"Ｋ", u"K"), (u"Ｌ", u"L"), (u"Ｍ", u"M"), (u"Ｎ", u"N"), (u"Ｏ", u"O"),
+            (u"Ｐ", u"P"), (u"Ｑ", u"Q"), (u"Ｒ", u"R"), (u"Ｓ", u"S"), (u"Ｔ", u"T"),
+            (u"Ｕ", u"U"), (u"Ｖ", u"V"), (u"Ｗ", u"W"), (u"Ｘ", u"X"), (u"Ｙ", u"Y"),
+            (u"Ｚ", u"Z")]
+        # Use punctuation conversion with caution: "５．１２特大地震" may become "5.12特大地震" after conversion.
+        FH_PUNCTUATION = [
+            (u'％', u'%'), (u'！', u'!'), (u'＂', u'\"'), (u'＇', u'\''), (u'＃', u'#'),
+            (u'￥', u'$'), (u'＆', u'&'), (u'（', u'('), (u'）', u')'), (u'＊', u'*'),
+            (u'＋', u'+'), (u'，', u','), (u'－', u'-'), (u'．', u'.'), (u'／', u'/'),
+            (u'：', u':'), (u'；', u';'), (u'＜', u'<'), (u'＝', u'='), (u'＞', u'>'),
+            (u'？', u'?'), (u'＠', u'@'), (u'［', u'['), (u'］', u']'), (u'＼', u'\\'),
+            (u'＾', u'^'), (u'＿', u'_'), (u'｀', u'`'), (u'～', u'~'), (u'｛', u'{'),
+            (u'｝', u'}'), (u'｜', u'|')]
+        FHs = []
+        if self.change_alpha:
+            FHs = FH_ALPHA
+        if self.change_digit:
+            FHs += FH_NUM
+        if self.change_punctuation:
+            FHs += FH_PUNCTUATION
+        if self.change_space:
+            FHs += FH_SPACE
+        self.convert_map = {k: v for k, v in FHs}
+    def process(self, dataset):
+        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
+        for ins in dataset:
+            sentence = ins[self.field_name].text
+            new_sentence = [None]*len(sentence)
+            for idx, char in enumerate(sentence):
+                if char in self.convert_map:
+                    char = self.convert_map[char]
+                new_sentence[idx] = char
+            ins[self.field_name].text = ''.join(new_sentence)
+        return dataset
+
+
+class IndexerProcessor(Processor):
+    def __init__(self, vocab, field_name):
+
+        assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
+
+        super(IndexerProcessor, self).__init__(field_name, None)
+        self.vocab = vocab
+
+    def set_vocab(self, vocab):
+        assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
+
+        self.vocab = vocab
+
+    def process(self, dataset):
+        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
+        for ins in dataset:
+            tokens = ins[self.field_name].content
+            index = [self.vocab.to_index(token) for token in tokens]
+            ins[self.field_name]._index = index
+
+        return dataset
+
+
+class VocabProcessor(Processor):
+    def __init__(self, field_name):
+
+        super(VocabProcessor, self).__init__(field_name, None)
+        self.vocab = Vocabulary()
+
+    def process(self, *datasets):
+        for dataset in datasets:
+            assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
+            for ins in dataset:
+                tokens = ins[self.field_name].content
+                self.vocab.update(tokens)
+
+    def get_vocab(self):
+        self.vocab.build_vocab()
+        return self.vocab
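
The three classes added above are meant to chain into a small preprocessing pipeline. A minimal usage sketch; the field names ('raw_sentence', 'tokens') and the dataset variables are illustrative only, not taken from this patch, and each dataset is assumed to be a fastNLP DataSet:

    from fastNLP.api.processor import (FullSpaceToHalfSpaceProcessor,
                                       IndexerProcessor, VocabProcessor)

    # Normalize full-width characters; Processor.__call__ forwards to process().
    train_data = FullSpaceToHalfSpaceProcessor('raw_sentence')(train_data)

    # Accumulate token counts across splits, then freeze and fetch the Vocabulary.
    vocab_proc = VocabProcessor('tokens')
    vocab_proc.process(train_data, dev_data)
    vocab = vocab_proc.get_vocab()   # calls build_vocab() and returns the Vocabulary

    # Map each token field to vocabulary indices.
    train_data = IndexerProcessor(vocab, 'tokens').process(train_data)
    dev_data = IndexerProcessor(vocab, 'tokens').process(dev_data)

Note that VocabProcessor.process deliberately returns nothing (see the comment kept in cws_processor.py below); only get_vocab yields a value.
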
diff --git a/reproduction/chinese_word_segment/process/cws_processor.py b/reproduction/chinese_word_segment/process/cws_processor.py
index 1f7c0fc1..bb76b974 100644
--- a/reproduction/chinese_word_segment/process/cws_processor.py
+++ b/reproduction/chinese_word_segment/process/cws_processor.py
@@ -11,65 +11,6 @@ from fastNLP.api.processor import Processor
 
 _SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>'
 
-class FullSpaceToHalfSpaceProcessor(Processor):
-    def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True,
-                 change_space=True):
-        super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None)
-
-        self.change_alpha = change_alpha
-        self.change_digit = change_digit
-        self.change_punctuation = change_punctuation
-        self.change_space = change_space
-
-        FH_SPACE = [(u"　", u" ")]
-        FH_NUM = [
-            (u"０", u"0"), (u"１", u"1"), (u"２", u"2"), (u"３", u"3"), (u"４", u"4"),
-            (u"５", u"5"), (u"６", u"6"), (u"７", u"7"), (u"８", u"8"), (u"９", u"9")]
-        FH_ALPHA = [
-            (u"ａ", u"a"), (u"ｂ", u"b"), (u"ｃ", u"c"), (u"ｄ", u"d"), (u"ｅ", u"e"),
-            (u"ｆ", u"f"), (u"ｇ", u"g"), (u"ｈ", u"h"), (u"ｉ", u"i"), (u"ｊ", u"j"),
-            (u"ｋ", u"k"), (u"ｌ", u"l"), (u"ｍ", u"m"), (u"ｎ", u"n"), (u"ｏ", u"o"),
-            (u"ｐ", u"p"), (u"ｑ", u"q"), (u"ｒ", u"r"), (u"ｓ", u"s"), (u"ｔ", u"t"),
-            (u"ｕ", u"u"), (u"ｖ", u"v"), (u"ｗ", u"w"), (u"ｘ", u"x"), (u"ｙ", u"y"),
-            (u"ｚ", u"z"),
-            (u"Ａ", u"A"), (u"Ｂ", u"B"), (u"Ｃ", u"C"), (u"Ｄ", u"D"), (u"Ｅ", u"E"),
-            (u"Ｆ", u"F"), (u"Ｇ", u"G"), (u"Ｈ", u"H"), (u"Ｉ", u"I"), (u"Ｊ", u"J"),
-            (u"Ｋ", u"K"), (u"Ｌ", u"L"), (u"Ｍ", u"M"), (u"Ｎ", u"N"), (u"Ｏ", u"O"),
-            (u"Ｐ", u"P"), (u"Ｑ", u"Q"), (u"Ｒ", u"R"), (u"Ｓ", u"S"), (u"Ｔ", u"T"),
-            (u"Ｕ", u"U"), (u"Ｖ", u"V"), (u"Ｗ", u"W"), (u"Ｘ", u"X"), (u"Ｙ", u"Y"),
-            (u"Ｚ", u"Z")]
-        # Use punctuation conversion with caution: "５．１２特大地震" may become "5.12特大地震" after conversion.
-        FH_PUNCTUATION = [
-            (u'％', u'%'), (u'！', u'!'), (u'＂', u'\"'), (u'＇', u'\''), (u'＃', u'#'),
-            (u'￥', u'$'), (u'＆', u'&'), (u'（', u'('), (u'）', u')'), (u'＊', u'*'),
-            (u'＋', u'+'), (u'，', u','), (u'－', u'-'), (u'．', u'.'), (u'／', u'/'),
-            (u'：', u':'), (u'；', u';'), (u'＜', u'<'), (u'＝', u'='), (u'＞', u'>'),
-            (u'？', u'?'), (u'＠', u'@'), (u'［', u'['), (u'］', u']'), (u'＼', u'\\'),
-            (u'＾', u'^'), (u'＿', u'_'), (u'｀', u'`'), (u'～', u'~'), (u'｛', u'{'),
-            (u'｝', u'}'), (u'｜', u'|')]
-        FHs = []
-        if self.change_alpha:
-            FHs = FH_ALPHA
-        if self.change_digit:
-            FHs += FH_NUM
-        if self.change_punctuation:
-            FHs += FH_PUNCTUATION
-        if self.change_space:
-            FHs += FH_SPACE
-        self.convert_map = {k: v for k, v in FHs}
-    def process(self, dataset):
-        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
-        for ins in dataset:
-            sentence = ins[self.field_name].text
-            new_sentence = [None]*len(sentence)
-            for idx, char in enumerate(sentence):
-                if char in self.convert_map:
-                    char = self.convert_map[char]
-                new_sentence[idx] = char
-            ins[self.field_name].text = ''.join(new_sentence)
-        return dataset
-
-
 class SpeicalSpanProcessor(Processor):
     # This class converts the special spans in a sentence into their corresponding content.
     def __init__(self, field_name, new_added_field_name=None):
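
The caveat in the comment above about punctuation conversion deserves a concrete illustration. A standalone sketch of the same full-width-to-half-width folding, using str.translate for brevity rather than the processor's character-by-character loop:

    # Full-width '５', '．', '１', '２' fold to their ASCII counterparts, so a
    # date-like '５．１２' becomes '5.12' and can now be misread as a decimal.
    table = {ord(u'５'): u'5', ord(u'．'): u'.', ord(u'１'): u'1', ord(u'２'): u'2'}
    print(u'５．１２特大地震'.translate(table))   # -> 5.12特大地震

This is why change_punctuation is a separate flag: folding digits and letters is usually safe, while folding the full-width full stop can change how a string parses.
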
@@ -93,7 +34,7 @@ class SpeicalSpanProcessor(Processor):
         return dataset
 
     def add_span_converter(self, converter):
-        assert isinstance(converter, SpanConverterBase), "Only SpanConverterBase is allowed, not {}."\
+        assert isinstance(converter, SpanConverter), "Only SpanConverter is allowed, not {}."\
             .format(type(converter))
         self.span_converters.append(converter)
 
@@ -243,28 +184,6 @@ class Pre2Post2BigramProcessor(BigramProcessor):
 # The vocabulary needs to be built at this point, but that runs into the following problem:
 # (1) if we follow the Processor pattern, what is returned in this case is not a dataset, so vocabulary
 #     building is implemented in another way instead of through a Processor.
-class IndexProcessor(Processor):
-    def __init__(self, vocab, field_name):
-
-        assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
-
-        super(IndexProcessor, self).__init__(field_name, None)
-        self.vocab = vocab
-
-    def set_vocab(self, vocab):
-        assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
-
-        self.vocab = vocab
-
-    def process(self, dataset):
-        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
-        for ins in dataset:
-            tokens = ins[self.field_name].content
-            index = [self.vocab.to_index(token) for token in tokens]
-            ins[self.field_name]._index = index
-
-        return dataset
-
 
 class VocabProcessor(Processor):
     def __init__(self, field_name):
@@ -272,11 +191,12 @@ class VocabProcessor(Processor):
 
         super(VocabProcessor, self).__init__(field_name, None)
         self.vocab = Vocabulary()
 
-    def process(self, dataset):
-        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
-        for ins in dataset:
-            tokens = ins[self.field_name].content
-            self.vocab.update(tokens)
+    def process(self, *datasets):
+        for dataset in datasets:
+            assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
+            for ins in dataset:
+                tokens = ins[self.field_name].content
+                self.vocab.update(tokens)
 
     def get_vocab(self):
         self.vocab.build_vocab()
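
The change from process(self, dataset) to process(self, *datasets) lets one vocabulary be accumulated over several splits before it is frozen. A sketch of the intended call pattern; the split variables are illustrative:

    vocab_proc = VocabProcessor('tokens')
    vocab_proc.process(train_data, dev_data, test_data)   # counts are shared across splits
    vocab = vocab_proc.get_vocab()                        # build_vocab() fixes word ids here

Building over all splits keeps dev/test tokens out of the unknown bucket; whether that is appropriate depends on the evaluation protocol, which this patch leaves to the caller.
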