From 1806bbdbec72ebc926348bc70ae98739b699fbf2 Mon Sep 17 00:00:00 2001 From: yunfan Date: Sat, 10 Nov 2018 15:13:53 +0800 Subject: [PATCH] fix dataset --- fastNLP/api/parser.py | 9 +++++++-- fastNLP/core/dataset.py | 2 ++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/fastNLP/api/parser.py b/fastNLP/api/parser.py index 6cfdd944..67bcca4f 100644 --- a/fastNLP/api/parser.py +++ b/fastNLP/api/parser.py @@ -3,6 +3,7 @@ from fastNLP.core.dataset import DataSet from fastNLP.core.predictor import Predictor from fastNLP.api.pipeline import Pipeline from fastNLP.api.processor import * +from fastNLP.models.biaffine_parser import BiaffineParser class DependencyParser(API): @@ -23,9 +24,13 @@ class DependencyParser(API): def build(self): pipe = Pipeline() + # build pipeline word_seq = 'word_seq' pos_seq = 'pos_seq' - pipe.add_processor(Num2TagProcessor('', word_seq)) + pipe.add_processor(Num2TagProcessor('', 'raw_sentence', word_seq)) pipe.add_processor(IndexerProcessor(word_vocab, word_seq, word_seq+'_idx')) pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, pos_seq+'_idx')) - pipe.add_processor() + + # load model parameters + self.model = BiaffineParser() + self.pipeline = pipe diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index e3162356..82b55818 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -86,6 +86,8 @@ class DataSet(object): return self.field_arrays[name] def __len__(self): + if len(self.field_arrays) == 0: + return 0 field = iter(self.field_arrays.values()).__next__() return len(field)