Browse Source

fix tests

tags/v0.2.0
yunfan 6 years ago
parent
commit
baac29cfa0
6 changed files with 36 additions and 12 deletions
  1. +1
    -1
      fastNLP/core/instance.py
  2. +2
    -1
      fastNLP/fastnlp.py
  3. +13
    -0
      fastNLP/loader/dataset_loader.py
  4. +1
    -0
      test/core/test_predictor.py
  5. +6
    -3
      test/model/test_cws.py
  6. +13
    -7
      test/model/test_seq_label.py

+ 1
- 1
fastNLP/core/instance.py View File

@@ -20,7 +20,7 @@ class Instance(object):
if old_name in self.indexes:
self.indexes[new_name] = self.indexes.pop(old_name)
else:
print("error, no such field: {}".format(old_name))
raise KeyError("error, no such field: {}".format(old_name))
return self

def set_target(self, **fields):


+ 2
- 1
fastNLP/fastnlp.py View File

@@ -182,7 +182,8 @@ class FastNLP(object):
if self.infer_type in ["seq_label", "text_class"]:
data_set = convert_seq_dataset(infer_input)
data_set.index_field("word_seq", self.word_vocab)
data_set.set_origin_len("word_seq")
if self.infer_type == "seq_label":
data_set.set_origin_len("word_seq")
return data_set
else:
raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type))


+ 13
- 0
fastNLP/loader/dataset_loader.py View File

@@ -77,6 +77,19 @@ class DataSetLoader(BaseLoader):
def load(self, path):
raise NotImplementedError

class RawDataSetLoader(DataSetLoader):
    """Load a plain-text file line by line into a sequence DataSet.

    Each line of the file becomes one raw instance; when ``split`` is
    given, every line is additionally tokenized with ``str.split(split)``.
    """

    def __init__(self):
        super(RawDataSetLoader, self).__init__()

    def load(self, data_path, split=None):
        """Read *data_path* and convert its lines into a DataSet.

        :param data_path: path of the UTF-8 text file to read
        :param split: optional separator; when given, each line is split on it
        :return: the result of :meth:`convert` (a sequence DataSet)
        """
        with open(data_path, "r", encoding="utf-8") as fin:
            raw = fin.readlines()
        if split is not None:
            raw = [line.split(split) for line in raw]
        # Drop empty entries (only possible for token lists after splitting,
        # since readlines() never yields an empty string).
        raw = [entry for entry in raw if len(entry) > 0]
        return self.convert(raw)

    def convert(self, data):
        # Delegate the actual DataSet construction to the shared helper.
        return convert_seq_dataset(data)

class POSDataSetLoader(DataSetLoader):
"""Dataset Loader for POS Tag datasets.


+ 1
- 0
test/core/test_predictor.py View File

@@ -56,6 +56,7 @@ class TestPredictor(unittest.TestCase):
self.assertTrue(res in class_vocab.word2idx)

del model, predictor
infer_data_set.set_origin_len("word_seq")

model = SeqLabeling(model_args)
predictor = Predictor("./save/", pre.seq_label_post_processor)


+ 6
- 3
test/model/test_cws.py View File

@@ -8,7 +8,7 @@ from fastNLP.core.preprocess import save_pickle, load_pickle
from fastNLP.core.tester import SeqLabelTester
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader, RawDataSetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.saver.model_saver import ModelSaver
@@ -38,9 +38,9 @@ def infer():
print("model loaded!")

# Load infer data
infer_data = TokenizeDataSetLoader().load(data_infer_path)
infer_data = RawDataSetLoader().load(data_infer_path)
infer_data.index_field("word_seq", word2index)
infer_data.set_origin_len("word_seq")
# inference
infer = SeqLabelInfer(pickle_path)
results = infer.predict(model, infer_data)
@@ -57,6 +57,9 @@ def train_test():
word_vocab = Vocabulary()
label_vocab = Vocabulary()
data_train.update_vocab(word_seq=word_vocab, label_seq=label_vocab)
data_train.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab)
data_train.set_origin_len("word_seq")
data_train.rename_field("label_seq", "truth").set_target(truth=False)
train_args["vocab_size"] = len(word_vocab)
train_args["num_classes"] = len(label_vocab)



+ 13
- 7
test/model/test_seq_label.py View File

@@ -1,6 +1,7 @@
import os

from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.preprocess import save_pickle
@@ -25,14 +26,19 @@ def test_training():
ConfigLoader().load_config(config_dir, {
"test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})

data_set = DataSet()
word_vocab = Vocabulary()
data_set = TokenizeDataSetLoader().load(data_path)
word_vocab = Vocabulary()
label_vocab = Vocabulary()
data_set.update_vocab(word_seq=word_vocab, label_seq=label_vocab)
data_set.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab)
data_set.set_origin_len("word_seq")
data_set.rename_field("label_seq", "truth").set_target(truth=False)
data_train, data_dev = data_set.split(0.3, shuffle=True)
model_args["vocab_size"] = len(data_set.word_vocab)
model_args["num_classes"] = len(data_set.label_vocab)
model_args["vocab_size"] = len(word_vocab)
model_args["num_classes"] = len(label_vocab)

save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl")
save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl")
save_pickle(word_vocab, pickle_path, "word2id.pkl")
save_pickle(label_vocab, pickle_path, "label2id.pkl")

trainer = SeqLabelTrainer(
epochs=trainer_args["epochs"],


Loading…
Cancel
Save