Browse Source

移动processor到processor.py

tags/v0.2.0
yh 5 years ago
parent
commit
515e4f4987
2 changed files with 111 additions and 88 deletions
  1. +104
    -1
      fastNLP/api/processor.py
  2. +7
    -87
      reproduction/chinese_word_segment/process/cws_processor.py

+ 104
- 1
fastNLP/api/processor.py View File

@@ -1,4 +1,6 @@

from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary

class Processor:
def __init__(self, field_name, new_added_field_name):
@@ -12,4 +14,105 @@ class Processor:
pass

def __call__(self, *args, **kwargs):
return self.process(*args, **kwargs)
return self.process(*args, **kwargs)



class FullSpaceToHalfSpaceProcessor(Processor):
def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True,
change_space=True):
super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None)

self.change_alpha = change_alpha
self.change_digit = change_digit
self.change_punctuation = change_punctuation
self.change_space = change_space

FH_SPACE = [(u" ", u" ")]
FH_NUM = [
(u"0", u"0"), (u"1", u"1"), (u"2", u"2"), (u"3", u"3"), (u"4", u"4"),
(u"5", u"5"), (u"6", u"6"), (u"7", u"7"), (u"8", u"8"), (u"9", u"9")]
FH_ALPHA = [
(u"a", u"a"), (u"b", u"b"), (u"c", u"c"), (u"d", u"d"), (u"e", u"e"),
(u"f", u"f"), (u"g", u"g"), (u"h", u"h"), (u"i", u"i"), (u"j", u"j"),
(u"k", u"k"), (u"l", u"l"), (u"m", u"m"), (u"n", u"n"), (u"o", u"o"),
(u"p", u"p"), (u"q", u"q"), (u"r", u"r"), (u"s", u"s"), (u"t", u"t"),
(u"u", u"u"), (u"v", u"v"), (u"w", u"w"), (u"x", u"x"), (u"y", u"y"),
(u"z", u"z"),
(u"A", u"A"), (u"B", u"B"), (u"C", u"C"), (u"D", u"D"), (u"E", u"E"),
(u"F", u"F"), (u"G", u"G"), (u"H", u"H"), (u"I", u"I"), (u"J", u"J"),
(u"K", u"K"), (u"L", u"L"), (u"M", u"M"), (u"N", u"N"), (u"O", u"O"),
(u"P", u"P"), (u"Q", u"Q"), (u"R", u"R"), (u"S", u"S"), (u"T", u"T"),
(u"U", u"U"), (u"V", u"V"), (u"W", u"W"), (u"X", u"X"), (u"Y", u"Y"),
(u"Z", u"Z")]
# 谨慎使用标点符号转换, 因为"5.12特大地震"转换后可能就成了"5.12特大地震"
FH_PUNCTUATION = [
(u'%', u'%'), (u'!', u'!'), (u'"', u'\"'), (u''', u'\''), (u'#', u'#'),
(u'¥', u'$'), (u'&', u'&'), (u'(', u'('), (u')', u')'), (u'*', u'*'),
(u'+', u'+'), (u',', u','), (u'-', u'-'), (u'.', u'.'), (u'/', u'/'),
(u':', u':'), (u';', u';'), (u'<', u'<'), (u'=', u'='), (u'>', u'>'),
(u'?', u'?'), (u'@', u'@'), (u'[', u'['), (u']', u']'), (u'\', u'\\'),
(u'^', u'^'), (u'_', u'_'), (u'`', u'`'), (u'~', u'~'), (u'{', u'{'),
(u'}', u'}'), (u'|', u'|')]
FHs = []
if self.change_alpha:
FHs = FH_ALPHA
if self.change_digit:
FHs += FH_NUM
if self.change_punctuation:
FHs += FH_PUNCTUATION
if self.change_space:
FHs += FH_SPACE
self.convert_map = {k: v for k, v in FHs}
def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
sentence = ins[self.field_name].text
new_sentence = [None]*len(sentence)
for idx, char in enumerate(sentence):
if char in self.convert_map:
char = self.convert_map[char]
new_sentence[idx] = char
ins[self.field_name].text = ''.join(new_sentence)
return dataset


class IndexerProcessor(Processor):
def __init__(self, vocab, field_name):

assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))

super(IndexerProcessor, self).__init__(field_name, None)
self.vocab = vocab

def set_vocab(self, vocab):
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))

self.vocab = vocab

def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
tokens = ins[self.field_name].content
index = [self.vocab.to_index(token) for token in tokens]
ins[self.field_name]._index = index

return dataset


class VocabProcessor(Processor):
def __init__(self, field_name):

super(VocabProcessor, self).__init__(field_name, None)
self.vocab = Vocabulary()

def process(self, *datasets):
for dataset in datasets:
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
tokens = ins[self.field_name].content
self.vocab.update(tokens)

def get_vocab(self):
self.vocab.build_vocab()
return self.vocab

+ 7
- 87
reproduction/chinese_word_segment/process/cws_processor.py View File

@@ -11,65 +11,6 @@ from fastNLP.api.processor import Processor

_SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>'

class FullSpaceToHalfSpaceProcessor(Processor):
def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True,
change_space=True):
super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None)

self.change_alpha = change_alpha
self.change_digit = change_digit
self.change_punctuation = change_punctuation
self.change_space = change_space

FH_SPACE = [(u" ", u" ")]
FH_NUM = [
(u"0", u"0"), (u"1", u"1"), (u"2", u"2"), (u"3", u"3"), (u"4", u"4"),
(u"5", u"5"), (u"6", u"6"), (u"7", u"7"), (u"8", u"8"), (u"9", u"9")]
FH_ALPHA = [
(u"a", u"a"), (u"b", u"b"), (u"c", u"c"), (u"d", u"d"), (u"e", u"e"),
(u"f", u"f"), (u"g", u"g"), (u"h", u"h"), (u"i", u"i"), (u"j", u"j"),
(u"k", u"k"), (u"l", u"l"), (u"m", u"m"), (u"n", u"n"), (u"o", u"o"),
(u"p", u"p"), (u"q", u"q"), (u"r", u"r"), (u"s", u"s"), (u"t", u"t"),
(u"u", u"u"), (u"v", u"v"), (u"w", u"w"), (u"x", u"x"), (u"y", u"y"),
(u"z", u"z"),
(u"A", u"A"), (u"B", u"B"), (u"C", u"C"), (u"D", u"D"), (u"E", u"E"),
(u"F", u"F"), (u"G", u"G"), (u"H", u"H"), (u"I", u"I"), (u"J", u"J"),
(u"K", u"K"), (u"L", u"L"), (u"M", u"M"), (u"N", u"N"), (u"O", u"O"),
(u"P", u"P"), (u"Q", u"Q"), (u"R", u"R"), (u"S", u"S"), (u"T", u"T"),
(u"U", u"U"), (u"V", u"V"), (u"W", u"W"), (u"X", u"X"), (u"Y", u"Y"),
(u"Z", u"Z")]
# 谨慎使用标点符号转换, 因为"5.12特大地震"转换后可能就成了"5.12特大地震"
FH_PUNCTUATION = [
(u'%', u'%'), (u'!', u'!'), (u'"', u'\"'), (u''', u'\''), (u'#', u'#'),
(u'¥', u'$'), (u'&', u'&'), (u'(', u'('), (u')', u')'), (u'*', u'*'),
(u'+', u'+'), (u',', u','), (u'-', u'-'), (u'.', u'.'), (u'/', u'/'),
(u':', u':'), (u';', u';'), (u'<', u'<'), (u'=', u'='), (u'>', u'>'),
(u'?', u'?'), (u'@', u'@'), (u'[', u'['), (u']', u']'), (u'\', u'\\'),
(u'^', u'^'), (u'_', u'_'), (u'`', u'`'), (u'~', u'~'), (u'{', u'{'),
(u'}', u'}'), (u'|', u'|')]
FHs = []
if self.change_alpha:
FHs = FH_ALPHA
if self.change_digit:
FHs += FH_NUM
if self.change_punctuation:
FHs += FH_PUNCTUATION
if self.change_space:
FHs += FH_SPACE
self.convert_map = {k: v for k, v in FHs}
def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
sentence = ins[self.field_name].text
new_sentence = [None]*len(sentence)
for idx, char in enumerate(sentence):
if char in self.convert_map:
char = self.convert_map[char]
new_sentence[idx] = char
ins[self.field_name].text = ''.join(new_sentence)
return dataset


class SpeicalSpanProcessor(Processor):
# 这个类会将句子中的special span转换为对应的内容。
def __init__(self, field_name, new_added_field_name=None):
@@ -93,7 +34,7 @@ class SpeicalSpanProcessor(Processor):
return dataset

def add_span_converter(self, converter):
assert isinstance(converter, SpanConverterBase), "Only SpanConverterBase is allowed, not {}."\
assert isinstance(converter, SpanConverter), "Only SpanConverterBase is allowed, not {}."\
.format(type(converter))
self.span_converters.append(converter)

@@ -243,28 +184,6 @@ class Pre2Post2BigramProcessor(BigramProcessor):
# 这里需要建立vocabulary了,但是遇到了以下的问题
# (1) 如果使用Processor的方式的话,但是在这种情况返回的不是dataset。所以建立vocabulary的工作用另外的方式实现,不借用
# Processor了
class IndexProcessor(Processor):
def __init__(self, vocab, field_name):

assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))

super(IndexProcessor, self).__init__(field_name, None)
self.vocab = vocab

def set_vocab(self, vocab):
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))

self.vocab = vocab

def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
tokens = ins[self.field_name].content
index = [self.vocab.to_index(token) for token in tokens]
ins[self.field_name]._index = index

return dataset


class VocabProcessor(Processor):
def __init__(self, field_name):
@@ -272,11 +191,12 @@ class VocabProcessor(Processor):
super(VocabProcessor, self).__init__(field_name, None)
self.vocab = Vocabulary()

def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
tokens = ins[self.field_name].content
self.vocab.update(tokens)
def process(self, *datasets):
for dataset in datasets:
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
tokens = ins[self.field_name].content
self.vocab.update(tokens)

def get_vocab(self):
self.vocab.build_vocab()


Loading…
Cancel
Save