diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py
index 9c20c2a6..996d0b17 100644
--- a/fastNLP/api/api.py
+++ b/fastNLP/api/api.py
@@ -14,3 +14,8 @@ class API:
         _dict = torch.load(name)
         self.pipeline = _dict['pipeline']
         self.model = _dict['model']
+
+    def save(self, path):
+        _dict = {'pipeline': self.pipeline,
+                 'model': self.model}
+        torch.save(_dict, path)
diff --git a/fastNLP/api/parser.py b/fastNLP/api/parser.py
new file mode 100644
index 00000000..67bcca4f
--- /dev/null
+++ b/fastNLP/api/parser.py
@@ -0,0 +1,36 @@
+from fastNLP.api.api import API
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.predictor import Predictor
+from fastNLP.api.pipeline import Pipeline
+from fastNLP.api.processor import *
+from fastNLP.models.biaffine_parser import BiaffineParser
+
+
+class DependencyParser(API):
+    def __init__(self):
+        super(DependencyParser, self).__init__()
+
+    def predict(self, data):
+        self.load('xxx')
+
+        dataset = DataSet()
+        dataset = self.pipeline.process(dataset)
+
+        pred = Predictor()
+        res = pred.predict(self.model, dataset)
+
+        return res
+
+    def build(self):
+        pipe = Pipeline()
+
+        # build pipeline
+        word_seq = 'word_seq'
+        pos_seq = 'pos_seq'
+        pipe.add_processor(Num2TagProcessor('', 'raw_sentence', word_seq))
+        pipe.add_processor(IndexerProcessor(word_vocab, word_seq, word_seq+'_idx'))
+        pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, pos_seq+'_idx'))
+
+        # load model parameters
+        self.model = BiaffineParser()
+        self.pipeline = pipe
diff --git a/fastNLP/api/pipeline.py b/fastNLP/api/pipeline.py
index 0edceb19..83aef66e 100644
--- a/fastNLP/api/pipeline.py
+++ b/fastNLP/api/pipeline.py
@@ -19,4 +19,7 @@ class Pipeline:
         return dataset
 
     def __call__(self, *args, **kwargs):
-        return self.process(*args, **kwargs)
\ No newline at end of file
+        return self.process(*args, **kwargs)
+
+    def __getitem__(self, item):
+        return self.pipeline[item]
diff --git a/fastNLP/api/pos_tagger.py b/fastNLP/api/pos_tagger.py
index fbd689c1..2157231e 100644
--- a/fastNLP/api/pos_tagger.py
+++ b/fastNLP/api/pos_tagger.py
@@ -5,9 +5,10 @@ import numpy as np
 from fastNLP.core.dataset import DataSet
 from fastNLP.loader.model_loader import ModelLoader
 from fastNLP.core.predictor import Predictor
+from fastNLP.api.api import API
 
 
-class POS_tagger:
+class POS_tagger(API):
     def __init__(self):
         pass
 
diff --git a/fastNLP/api/processor.py b/fastNLP/api/processor.py
index 3f8cc057..d21c1050 100644
--- a/fastNLP/api/processor.py
+++ b/fastNLP/api/processor.py
@@ -2,6 +2,8 @@ from fastNLP.core.dataset import DataSet
 from fastNLP.core.vocabulary import Vocabulary
 
+import re
+
 
 class Processor:
     def __init__(self, field_name, new_added_field_name):
         self.field_name = field_name
@@ -64,6 +66,7 @@ class FullSpaceToHalfSpaceProcessor(Processor):
         if self.change_space:
             FHs += FH_SPACE
         self.convert_map = {k: v for k, v in FHs}
+
     def process(self, dataset):
         assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
         for ins in dataset:
@@ -77,6 +80,37 @@ class FullSpaceToHalfSpaceProcessor(Processor):
         return dataset
 
 
+class MapFieldProcessor(Processor):
+    def __init__(self, func, field_name, new_added_field_name=None):
+        super(MapFieldProcessor, self).__init__(field_name, new_added_field_name)
+        self.func = func
+
+    def process(self, dataset):
+        for ins in dataset:
+            s = ins[self.field_name]
+            new_s = self.func(s)
+            ins[self.new_added_field_name] = new_s
+        return dataset
+
+
+class Num2TagProcessor(Processor):
+    def __init__(self, tag, field_name, new_added_field_name=None):
+        super(Num2TagProcessor, self).__init__(field_name, new_added_field_name)
+        self.tag = tag
+        self.pattern = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)'
+
+    def process(self, dataset):
+        for ins in dataset:
+            s = ins[self.field_name]
+            new_s = [None] * len(s)
+            for i, w in enumerate(s):
+                if re.search(self.pattern, w) is not None:
+                    w = self.tag
+                new_s[i] = w
+            ins[self.new_added_field_name] = new_s
+        return dataset
+
+
 class IndexerProcessor(Processor):
     def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False):
 
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index e3162356..82b55818 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -86,6 +86,8 @@ class DataSet(object):
         return self.field_arrays[name]
 
     def __len__(self):
+        if len(self.field_arrays) == 0:
+            return 0
         field = iter(self.field_arrays.values()).__next__()
         return len(field)
 
diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py
index cd68d35d..11cde48a 100644
--- a/fastNLP/modules/decoder/CRF.py
+++ b/fastNLP/modules/decoder/CRF.py
@@ -89,7 +89,7 @@ class ConditionalRandomField(nn.Module):
         score = score.sum(0) + emit_score[-1]
         if self.include_start_end_trans:
             st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
-            last_idx = masks.long().sum(0)
+            last_idx = mask.long().sum(0)
             ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]]
             score += st_scores + ed_scores
         # return [B,]
diff --git a/reproduction/Biaffine_parser/run.py b/reproduction/Biaffine_parser/run.py
index 45668066..209e45cb 100644
--- a/reproduction/Biaffine_parser/run.py
+++ b/reproduction/Biaffine_parser/run.py
@@ -352,7 +352,7 @@ if __name__ == "__main__":
     elif args.mode == 'test':
         test(args.path)
     elif args.mode == 'infer':
-        infer()
+        pass
     else:
         print('no mode specified for model!')
         parser.print_help()
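
Note on the new Num2TagProcessor: its number normalisation can be illustrated outside of fastNLP. The sketch below is only a standard-library approximation of what Num2TagProcessor.process does to a single token list; the num2tag helper and the '<NUM>' tag are hypothetical names chosen for the example (in parser.py the tag string is the first constructor argument), while the regex is copied verbatim from the diff.

import re

# Pattern copied from Num2TagProcessor above; it matches number-like tokens
# such as '6.5', '2018' or '1e-3' (a bare single digit like '7' does not match).
NUM_PATTERN = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)'


def num2tag(tokens, tag='<NUM>'):
    """Replace number-like tokens with `tag`, mirroring Num2TagProcessor.process."""
    return [tag if re.search(NUM_PATTERN, w) is not None else w for w in tokens]


if __name__ == '__main__':
    print(num2tag(['GDP', 'grew', '6.5', 'percent', 'in', '2018', '.']))
    # -> ['GDP', 'grew', '<NUM>', 'percent', 'in', '<NUM>', '.']

The new API.save mirrors the existing API.load shown in the first hunk: it bundles self.pipeline and self.model into a dict and hands it to torch.save, so a file written by save can be restored later by load.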