@@ -14,3 +14,8 @@ class API: | |||||
_dict = torch.load(name) | _dict = torch.load(name) | ||||
self.pipeline = _dict['pipeline'] | self.pipeline = _dict['pipeline'] | ||||
self.model = _dict['model'] | self.model = _dict['model'] | ||||
def save(self, path):
    """Persist this API object's pipeline and model to *path* with torch.save.

    The checkpoint is a dict with keys 'pipeline' and 'model', the exact
    format that the companion ``load`` method reads back.
    """
    state = {'pipeline': self.pipeline, 'model': self.model}
    torch.save(state, path)
@@ -0,0 +1,36 @@ | |||||
from fastNLP.api.api import API | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.predictor import Predictor | |||||
from fastNLP.api.pipeline import Pipeline | |||||
from fastNLP.api.processor import * | |||||
from fastNLP.models.biaffine_parser import BiaffineParser | |||||
class DependencyParser(API):
    """High-level inference wrapper for the biaffine dependency parser.

    Combines a preprocessing ``Pipeline`` with a ``BiaffineParser`` model,
    following the checkpoint format of the shared ``API`` base class.
    """

    def __init__(self):
        super(DependencyParser, self).__init__()

    def predict(self, data):
        """Run the saved pipeline + model and return raw predictions.

        NOTE(review): 'xxx' is a placeholder checkpoint path, and *data* is
        never inserted into the freshly created DataSet before processing —
        confirm the intended data flow with the caller before shipping.
        """
        self.load('xxx')

        dataset = DataSet()
        dataset = self.pipeline.process(dataset)

        pred = Predictor()
        res = pred.predict(self.model, dataset)
        return res

    def build(self, word_vocab=None, pos_vocab=None):
        """Construct the preprocessing pipeline and the parser model.

        :param word_vocab: vocabulary used to index the word sequence.
        :param pos_vocab: vocabulary used to index the POS-tag sequence.

        Fix: the original body referenced ``word_vocab`` / ``pos_vocab`` as
        undefined globals, which raised NameError on every call; they are now
        injectable parameters (defaulting to None to stay call-compatible).
        """
        pipe = Pipeline()

        # Field names used by the processors below.
        word_seq = 'word_seq'
        pos_seq = 'pos_seq'

        # Normalize numeric tokens, then map words/POS tags to indices.
        pipe.add_processor(Num2TagProcessor('<NUM>', 'raw_sentence', word_seq))
        pipe.add_processor(IndexerProcessor(word_vocab, word_seq, word_seq + '_idx'))
        pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, pos_seq + '_idx'))

        self.model = BiaffineParser()
        self.pipeline = pipe
@@ -19,4 +19,7 @@ class Pipeline: | |||||
return dataset | return dataset | ||||
def __call__(self, *args, **kwargs):
    """Make the pipeline callable: ``pipe(ds)`` is shorthand for ``pipe.process(ds)``."""
    return self.process(*args, **kwargs)
def __getitem__(self, item):
    """Index into the pipeline: ``pipe[i]`` returns the i-th entry of ``self.pipeline``.

    NOTE(review): assumes ``self.pipeline`` is the internal processor
    container — confirm against the class ``__init__`` (not visible here).
    """
    return self.pipeline[item]
@@ -5,9 +5,10 @@ import numpy as np | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
from fastNLP.core.predictor import Predictor | from fastNLP.core.predictor import Predictor | ||||
from fastNLP.api.api import API | |||||
class POS_tagger: | |||||
class POS_tagger(API): | |||||
def __init__(self): | def __init__(self): | ||||
pass | pass | ||||
@@ -2,6 +2,8 @@ | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
import re | |||||
class Processor: | class Processor: | ||||
def __init__(self, field_name, new_added_field_name): | def __init__(self, field_name, new_added_field_name): | ||||
self.field_name = field_name | self.field_name = field_name | ||||
@@ -64,6 +66,7 @@ class FullSpaceToHalfSpaceProcessor(Processor): | |||||
if self.change_space: | if self.change_space: | ||||
FHs += FH_SPACE | FHs += FH_SPACE | ||||
self.convert_map = {k: v for k, v in FHs} | self.convert_map = {k: v for k, v in FHs} | ||||
def process(self, dataset): | def process(self, dataset): | ||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | ||||
for ins in dataset: | for ins in dataset: | ||||
@@ -77,6 +80,37 @@ class FullSpaceToHalfSpaceProcessor(Processor): | |||||
return dataset | return dataset | ||||
class MapFieldProcessor(Processor):
    """Apply an arbitrary function to one field of every instance.

    For each instance, reads ``field_name``, applies ``func`` to it, and
    stores the result under ``new_added_field_name``.
    """

    def __init__(self, func, field_name, new_added_field_name=None):
        super(MapFieldProcessor, self).__init__(field_name, new_added_field_name)
        # The mapping callable applied to each instance's field value.
        self.func = func

    def process(self, dataset):
        """Transform every instance in *dataset* in place and return it."""
        for instance in dataset:
            instance[self.new_added_field_name] = self.func(instance[self.field_name])
        return dataset
class Num2TagProcessor(Processor):
    """Replace number-like tokens in a token-sequence field with a fixed tag.

    A token is considered number-like when ``self.pattern`` matches anywhere
    inside it (``re.search``), so tokens merely containing digits also match.
    """

    def __init__(self, tag, field_name, new_added_field_name=None):
        super(Num2TagProcessor, self).__init__(field_name, new_added_field_name)
        # Replacement tag, e.g. '<NUM>'.
        self.tag = tag
        # Regex describing numeric literals (optional sign, decimals, exponent).
        self.pattern = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)'

    def process(self, dataset):
        """Rewrite each instance's token list, storing it under the new field."""
        for ins in dataset:
            tokens = ins[self.field_name]
            ins[self.new_added_field_name] = [
                self.tag if re.search(self.pattern, tok) is not None else tok
                for tok in tokens
            ]
        return dataset
class IndexerProcessor(Processor): | class IndexerProcessor(Processor): | ||||
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False): | def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False): | ||||
@@ -86,6 +86,8 @@ class DataSet(object): | |||||
return self.field_arrays[name] | return self.field_arrays[name] | ||||
def __len__(self):
    """Return the number of instances; 0 when the dataset has no fields yet.

    All field arrays are equally long, so the length of any one of them is
    the dataset length.
    """
    if not self.field_arrays:
        return 0
    first_field = next(iter(self.field_arrays.values()))
    return len(first_field)
@@ -89,7 +89,7 @@ class ConditionalRandomField(nn.Module): | |||||
score = score.sum(0) + emit_score[-1] | score = score.sum(0) + emit_score[-1] | ||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] | st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] | ||||
last_idx = masks.long().sum(0) | |||||
last_idx = mask.long().sum(0) | |||||
ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] | ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] | ||||
score += st_scores + ed_scores | score += st_scores + ed_scores | ||||
# return [B,] | # return [B,] | ||||
@@ -352,7 +352,7 @@ if __name__ == "__main__": | |||||
elif args.mode == 'test': | elif args.mode == 'test': | ||||
test(args.path) | test(args.path) | ||||
elif args.mode == 'infer': | elif args.mode == 'infer': | ||||
infer() | |||||
pass | |||||
else: | else: | ||||
print('no mode specified for model!') | print('no mode specified for model!') | ||||
parser.print_help() | parser.print_help() |