diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index 9c20c2a6..996d0b17 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -14,3 +14,8 @@ class API: _dict = torch.load(name) self.pipeline = _dict['pipeline'] self.model = _dict['model'] + + def save(self, path): + _dict = {'pipeline': self.pipeline, + 'model': self.model} + torch.save(_dict, path) diff --git a/fastNLP/api/parser.py b/fastNLP/api/parser.py new file mode 100644 index 00000000..6cfdd944 --- /dev/null +++ b/fastNLP/api/parser.py @@ -0,0 +1,31 @@ +from fastNLP.api.api import API +from fastNLP.core.dataset import DataSet +from fastNLP.core.predictor import Predictor +from fastNLP.api.pipeline import Pipeline +from fastNLP.api.processor import * + + +class DependencyParser(API): + def __init__(self): + super(DependencyParser, self).__init__() + + def predict(self, data): + self.load('xxx') + + dataset = DataSet() + dataset = self.pipeline.process(dataset) + + pred = Predictor() + res = pred.predict(self.model, dataset) + + return res + + def build(self): + pipe = Pipeline() + + word_seq = 'word_seq' + pos_seq = 'pos_seq' + pipe.add_processor(Num2TagProcessor('', word_seq)) + pipe.add_processor(IndexerProcessor(word_vocab, word_seq, word_seq+'_idx')) + pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, pos_seq+'_idx')) + pipe.add_processor() diff --git a/fastNLP/api/pipeline.py b/fastNLP/api/pipeline.py index 745c8874..5e68022a 100644 --- a/fastNLP/api/pipeline.py +++ b/fastNLP/api/pipeline.py @@ -19,4 +19,7 @@ class Pipeline: return dataset def __call__(self, *args, **kwargs): - return self.process(*args, **kwargs) \ No newline at end of file + return self.process(*args, **kwargs) + + def __getitem__(self, item): + return self.pipeline[item] diff --git a/fastNLP/api/pos_tagger.py b/fastNLP/api/pos_tagger.py index fbd689c1..2157231e 100644 --- a/fastNLP/api/pos_tagger.py +++ b/fastNLP/api/pos_tagger.py @@ -5,9 +5,10 @@ import numpy as np from fastNLP.core.dataset import DataSet from fastNLP.loader.model_loader import ModelLoader from fastNLP.core.predictor import Predictor +from fastNLP.api.api import API -class POS_tagger: +class POS_tagger(API): def __init__(self): pass diff --git a/fastNLP/api/processor.py b/fastNLP/api/processor.py index 3f8cc057..24c98d1a 100644 --- a/fastNLP/api/processor.py +++ b/fastNLP/api/processor.py @@ -2,6 +2,8 @@ from fastNLP.core.dataset import DataSet from fastNLP.core.vocabulary import Vocabulary +import re + class Processor: def __init__(self, field_name, new_added_field_name): self.field_name = field_name @@ -64,6 +66,7 @@ class FullSpaceToHalfSpaceProcessor(Processor): if self.change_space: FHs += FH_SPACE self.convert_map = {k: v for k, v in FHs} + def process(self, dataset): assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) for ins in dataset: @@ -77,6 +80,37 @@ class FullSpaceToHalfSpaceProcessor(Processor): return dataset +class MapFieldProcessor(Processor): + def __init__(self, func, field_name, new_added_field_name=None): + super(MapFieldProcessor, self).__init__(field_name, new_added_field_name) + self.func = func + + def process(self, dataset): + for ins in dataset: + s = ins[self.field_name] + new_s = self.func(s) + ins[self.new_added_field_name] = new_s + return dataset + + +class Num2TagProcessor(Processor): + def __init__(self, tag, field_name, new_added_field_name=None): + super(Num2TagProcessor, self).__init__(field_name, new_added_field_name) + self.tag = tag + self.pattern = r'[-+]?[0-9]+[\./e]+[-+]?[0-9]*' + + def process(self, dataset): + for ins in dataset: + s = ins[self.field_name] + new_s = [None] * len(s) + for i, w in enumerate(s): + if re.search(self.pattern, w) is not None: + w = self.tag + new_s[i] = w + ins[self.new_added_field_name] = new_s + return dataset + + class IndexerProcessor(Processor): def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False):