import re

from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary


class Processor:
    """Base class for all processors.

    A processor reads ``field_name`` from every instance of a DataSet and
    writes its result to ``new_added_field_name``. When the latter is None,
    the result overwrites ``field_name`` in place.
    """

    def __init__(self, field_name, new_added_field_name=None):
        self.field_name = field_name
        if new_added_field_name is None:
            self.new_added_field_name = field_name
        else:
            self.new_added_field_name = new_added_field_name

    def process(self, *args, **kwargs):
        # Subclasses must implement the actual transformation.
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        # Calling a processor is shorthand for calling its process() method.
        return self.process(*args, **kwargs)
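

# Illustrative sketch (not part of the original module): the minimal shape of a
# custom Processor subclass. The lowercasing behavior and the field it touches
# are hypothetical examples.
class _LowerCaseProcessor(Processor):
    def process(self, dataset):
        for ins in dataset:
            ins[self.new_added_field_name] = ins[self.field_name].lower()
        return dataset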


class FullSpaceToHalfSpaceProcessor(Processor):
    """Converts full-width characters in a string field to their half-width
    (ASCII) equivalents, in place."""

    def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True,
                 change_space=True):
        super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None)

        self.change_alpha = change_alpha
        self.change_digit = change_digit
        self.change_punctuation = change_punctuation
        self.change_space = change_space

        FH_SPACE = [(u"　", u" ")]
        FH_NUM = [
            (u"０", u"0"), (u"１", u"1"), (u"２", u"2"), (u"３", u"3"), (u"４", u"4"),
            (u"５", u"5"), (u"６", u"6"), (u"７", u"7"), (u"８", u"8"), (u"９", u"9")]
        FH_ALPHA = [
            (u"ａ", u"a"), (u"ｂ", u"b"), (u"ｃ", u"c"), (u"ｄ", u"d"), (u"ｅ", u"e"),
            (u"ｆ", u"f"), (u"ｇ", u"g"), (u"ｈ", u"h"), (u"ｉ", u"i"), (u"ｊ", u"j"),
            (u"ｋ", u"k"), (u"ｌ", u"l"), (u"ｍ", u"m"), (u"ｎ", u"n"), (u"ｏ", u"o"),
            (u"ｐ", u"p"), (u"ｑ", u"q"), (u"ｒ", u"r"), (u"ｓ", u"s"), (u"ｔ", u"t"),
            (u"ｕ", u"u"), (u"ｖ", u"v"), (u"ｗ", u"w"), (u"ｘ", u"x"), (u"ｙ", u"y"),
            (u"ｚ", u"z"),
            (u"Ａ", u"A"), (u"Ｂ", u"B"), (u"Ｃ", u"C"), (u"Ｄ", u"D"), (u"Ｅ", u"E"),
            (u"Ｆ", u"F"), (u"Ｇ", u"G"), (u"Ｈ", u"H"), (u"Ｉ", u"I"), (u"Ｊ", u"J"),
            (u"Ｋ", u"K"), (u"Ｌ", u"L"), (u"Ｍ", u"M"), (u"Ｎ", u"N"), (u"Ｏ", u"O"),
            (u"Ｐ", u"P"), (u"Ｑ", u"Q"), (u"Ｒ", u"R"), (u"Ｓ", u"S"), (u"Ｔ", u"T"),
            (u"Ｕ", u"U"), (u"Ｖ", u"V"), (u"Ｗ", u"W"), (u"Ｘ", u"X"), (u"Ｙ", u"Y"),
            (u"Ｚ", u"Z")]
        # Use punctuation conversion with caution: e.g. "５．１２特大地震"
        # (the "5.12 massive earthquake") would become "5.12特大地震".
        FH_PUNCTUATION = [
            (u'％', u'%'), (u'！', u'!'), (u'＂', u'\"'), (u'＇', u'\''), (u'＃', u'#'),
            (u'＄', u'$'), (u'＆', u'&'), (u'（', u'('), (u'）', u')'), (u'＊', u'*'),
            (u'＋', u'+'), (u'，', u','), (u'－', u'-'), (u'．', u'.'), (u'／', u'/'),
            (u'：', u':'), (u'；', u';'), (u'＜', u'<'), (u'＝', u'='), (u'＞', u'>'),
            (u'？', u'?'), (u'＠', u'@'), (u'［', u'['), (u'］', u']'), (u'＼', u'\\'),
            (u'＾', u'^'), (u'＿', u'_'), (u'｀', u'`'), (u'～', u'~'), (u'｛', u'{'),
            (u'｝', u'}'), (u'｜', u'|')]
        FHs = []
        if self.change_alpha:
            FHs += FH_ALPHA  # extend a fresh list so the tables are never aliased
        if self.change_digit:
            FHs += FH_NUM
        if self.change_punctuation:
            FHs += FH_PUNCTUATION
        if self.change_space:
            FHs += FH_SPACE
        self.convert_map = {k: v for k, v in FHs}

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            sentence = ins[self.field_name]
            new_sentence = [None] * len(sentence)
            for idx, char in enumerate(sentence):
                if char in self.convert_map:
                    char = self.convert_map[char]
                new_sentence[idx] = char
            ins[self.field_name] = ''.join(new_sentence)
        return dataset
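

# Usage sketch (illustrative, not part of the original module). It assumes an
# Instance class with keyword construction and DataSet.append(), which may
# differ across fastNLP versions; the field name is hypothetical.
def _demo_full_to_half_space():
    from fastNLP.core.instance import Instance  # assumed import path
    ds = DataSet()
    ds.append(Instance(raw_sentence=u"ｆａｓｔＮＬＰ　２０１８"))
    ds = FullSpaceToHalfSpaceProcessor("raw_sentence")(ds)
    print(ds[0]["raw_sentence"])  # expected: fastNLP 2018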


class MapFieldProcessor(Processor):
    """Applies an arbitrary one-argument function to a field and stores the
    result in ``new_added_field_name`` (or back into ``field_name``)."""

    def __init__(self, func, field_name, new_added_field_name=None):
        super(MapFieldProcessor, self).__init__(field_name, new_added_field_name)
        self.func = func

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            s = ins[self.field_name]
            new_s = self.func(s)
            ins[self.new_added_field_name] = new_s
        return dataset
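

# Usage sketch (illustrative): MapFieldProcessor lifts any one-argument
# function over a field. Here plain str.split serves as a stand-in tokenizer;
# the field names are hypothetical.
def _demo_map_field():
    from fastNLP.core.instance import Instance  # assumed import path
    ds = DataSet()
    ds.append(Instance(raw_sentence=u"fastNLP is a toolkit"))
    ds = MapFieldProcessor(lambda s: s.split(), "raw_sentence", "words")(ds)
    print(ds[0]["words"])  # expected: ['fastNLP', 'is', 'a', 'toolkit']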


class Num2TagProcessor(Processor):
    """Replaces numeric tokens in a token-list field with a single tag."""

    def __init__(self, tag, field_name, new_added_field_name=None):
        super(Num2TagProcessor, self).__init__(field_name, new_added_field_name)
        self.tag = tag
        # Matches integers, decimals, simple fractions, and scientific
        # notation, e.g. "2018", "3.5", "1/2", "1e-5".
        self.pattern = r'[-+]?\d+([./]\d+)?([eE][-+]?\d+)?'

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            s = ins[self.field_name]
            new_s = [None] * len(s)
            for i, w in enumerate(s):
                if re.search(self.pattern, w) is not None:
                    w = self.tag
                new_s[i] = w
            ins[self.new_added_field_name] = new_s
        return dataset
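

# Usage sketch (illustrative): numeric tokens are collapsed to a single tag,
# here "<num>"; the tag string and field name are hypothetical.
def _demo_num2tag():
    from fastNLP.core.instance import Instance  # assumed import path
    ds = DataSet()
    ds.append(Instance(words=[u"prices", u"rose", u"3.5", u"percent", u"in", u"2018"]))
    ds = Num2TagProcessor("<num>", "words")(ds)
    print(ds[0]["words"])  # expected: ['prices', 'rose', '<num>', 'percent', 'in', '<num>']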


class IndexerProcessor(Processor):
    """Converts a token-list field into a list of vocabulary indices."""

    def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False):
        assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))

        super(IndexerProcessor, self).__init__(field_name, new_added_field_name)
        self.vocab = vocab
        self.delete_old_field = delete_old_field

    def set_vocab(self, vocab):
        assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
        self.vocab = vocab

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            tokens = ins[self.field_name]
            index = [self.vocab.to_index(token) for token in tokens]
            ins[self.new_added_field_name] = index

        # Mark the new index field as one that should be turned into tensors.
        dataset.set_need_tensor(**{self.new_added_field_name: True})

        if self.delete_old_field:
            dataset.delete_field(self.field_name)

        return dataset
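

# Usage sketch (illustrative): tokens are mapped to integer ids with a
# pre-built Vocabulary. It relies on the Vocabulary.update/build_vocab calls
# already used in this module; the field names are hypothetical.
def _demo_indexer():
    from fastNLP.core.instance import Instance  # assumed import path
    vocab = Vocabulary()
    vocab.update([u"fastNLP", u"is", u"fun"])
    vocab.build_vocab()  # finalize the mapping, as get_vocab() does below
    ds = DataSet()
    ds.append(Instance(words=[u"fastNLP", u"is", u"fun"]))
    ds = IndexerProcessor(vocab, "words", "word_ids")(ds)
    print(ds[0]["word_ids"])  # a list of integer ids, one per token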


class VocabProcessor(Processor):
    """Builds a Vocabulary from the tokens of a field across one or more
    datasets. Unlike the other processors, it does not modify the data."""

    def __init__(self, field_name):
        super(VocabProcessor, self).__init__(field_name, None)
        self.vocab = Vocabulary()

    def process(self, *datasets):
        for dataset in datasets:
            assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
            for ins in dataset:
                tokens = ins[self.field_name]
                self.vocab.update(tokens)

    def get_vocab(self):
        # Finalize the vocabulary (fix the token-to-index mapping) before use.
        self.vocab.build_vocab()
        return self.vocab
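

# Usage sketch (illustrative): a typical two-stage pipeline, where a
# VocabProcessor first scans the data and an IndexerProcessor then indexes it;
# the field names are hypothetical.
def _demo_vocab_then_index():
    from fastNLP.core.instance import Instance  # assumed import path
    ds = DataSet()
    ds.append(Instance(words=[u"fastNLP", u"is", u"fun"]))
    vocab_proc = VocabProcessor("words")
    vocab_proc(ds)  # scans the dataset, updating the internal vocabulary
    ds = IndexerProcessor(vocab_proc.get_vocab(), "words", "word_ids")(ds)
    print(ds[0]["word_ids"])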