@@ -20,7 +20,7 @@ class Instance(object): | |||||
if old_name in self.indexes: | if old_name in self.indexes: | ||||
self.indexes[new_name] = self.indexes.pop(old_name) | self.indexes[new_name] = self.indexes.pop(old_name) | ||||
else: | else: | ||||
print("error, no such field: {}".format(old_name)) | |||||
raise KeyError("error, no such field: {}".format(old_name)) | |||||
return self | return self | ||||
def set_target(self, **fields): | def set_target(self, **fields): | ||||
@@ -182,7 +182,8 @@ class FastNLP(object): | |||||
if self.infer_type in ["seq_label", "text_class"]: | if self.infer_type in ["seq_label", "text_class"]: | ||||
data_set = convert_seq_dataset(infer_input) | data_set = convert_seq_dataset(infer_input) | ||||
data_set.index_field("word_seq", self.word_vocab) | data_set.index_field("word_seq", self.word_vocab) | ||||
data_set.set_origin_len("word_seq") | |||||
if self.infer_type == "seq_label": | |||||
data_set.set_origin_len("word_seq") | |||||
return data_set | return data_set | ||||
else: | else: | ||||
raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type)) | raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type)) | ||||
@@ -77,6 +77,19 @@ class DataSetLoader(BaseLoader): | |||||
def load(self, path): | def load(self, path): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
class RawDataSetLoader(DataSetLoader):
    """Load raw text where each line of the file is one example.

    When ``split`` is given, every line is tokenized with
    ``str.split(split)`` first; otherwise the raw line (including its
    trailing newline) is kept as a single string.
    """

    def __init__(self):
        super(RawDataSetLoader, self).__init__()

    def load(self, data_path, split=None):
        """Read ``data_path`` (UTF-8) and return the converted data set.

        :param data_path: path to a plain-text file, one example per line.
        :param split: optional separator handed to ``str.split``; ``None``
            leaves each line un-tokenized.
        :return: whatever ``self.convert`` produces for the loaded lines.
        """
        with open(data_path, "r", encoding="utf-8") as fin:
            entries = fin.readlines()
        if split is not None:
            entries = [line.split(split) for line in entries]
        # NOTE(review): readlines() keeps the trailing "\n", so this filter
        # only discards entries of length zero — blank lines still pass
        # through as "\n" when split is None. Preserved as-is on purpose.
        kept = [entry for entry in entries if len(entry) > 0]
        return self.convert(kept)

    def convert(self, data):
        """Wrap the loaded examples into a sequence data set."""
        return convert_seq_dataset(data)
class POSDataSetLoader(DataSetLoader): | class POSDataSetLoader(DataSetLoader): | ||||
"""Dataset Loader for POS Tag datasets. | """Dataset Loader for POS Tag datasets. | ||||
@@ -56,6 +56,7 @@ class TestPredictor(unittest.TestCase): | |||||
self.assertTrue(res in class_vocab.word2idx) | self.assertTrue(res in class_vocab.word2idx) | ||||
del model, predictor | del model, predictor | ||||
infer_data_set.set_origin_len("word_seq") | |||||
model = SeqLabeling(model_args) | model = SeqLabeling(model_args) | ||||
predictor = Predictor("./save/", pre.seq_label_post_processor) | predictor = Predictor("./save/", pre.seq_label_post_processor) | ||||
@@ -8,7 +8,7 @@ from fastNLP.core.preprocess import save_pickle, load_pickle | |||||
from fastNLP.core.tester import SeqLabelTester | from fastNLP.core.tester import SeqLabelTester | ||||
from fastNLP.core.trainer import SeqLabelTrainer | from fastNLP.core.trainer import SeqLabelTrainer | ||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader | |||||
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader, RawDataSetLoader | |||||
from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
@@ -38,9 +38,9 @@ def infer(): | |||||
print("model loaded!") | print("model loaded!") | ||||
# Load infer data | # Load infer data | ||||
infer_data = TokenizeDataSetLoader().load(data_infer_path) | |||||
infer_data = RawDataSetLoader().load(data_infer_path) | |||||
infer_data.index_field("word_seq", word2index) | infer_data.index_field("word_seq", word2index) | ||||
infer_data.set_origin_len("word_seq") | |||||
# inference | # inference | ||||
infer = SeqLabelInfer(pickle_path) | infer = SeqLabelInfer(pickle_path) | ||||
results = infer.predict(model, infer_data) | results = infer.predict(model, infer_data) | ||||
@@ -57,6 +57,9 @@ def train_test(): | |||||
word_vocab = Vocabulary() | word_vocab = Vocabulary() | ||||
label_vocab = Vocabulary() | label_vocab = Vocabulary() | ||||
data_train.update_vocab(word_seq=word_vocab, label_seq=label_vocab) | data_train.update_vocab(word_seq=word_vocab, label_seq=label_vocab) | ||||
data_train.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab) | |||||
data_train.set_origin_len("word_seq") | |||||
data_train.rename_field("label_seq", "truth").set_target(truth=False) | |||||
train_args["vocab_size"] = len(word_vocab) | train_args["vocab_size"] = len(word_vocab) | ||||
train_args["num_classes"] = len(label_vocab) | train_args["num_classes"] = len(label_vocab) | ||||
@@ -1,6 +1,7 @@ | |||||
import os | import os | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | from fastNLP.core.metrics import SeqLabelEvaluator | ||||
from fastNLP.core.optimizer import Optimizer | from fastNLP.core.optimizer import Optimizer | ||||
from fastNLP.core.preprocess import save_pickle | from fastNLP.core.preprocess import save_pickle | ||||
@@ -25,14 +26,19 @@ def test_training(): | |||||
ConfigLoader().load_config(config_dir, { | ConfigLoader().load_config(config_dir, { | ||||
"test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | ||||
data_set = DataSet() | |||||
word_vocab = V | |||||
data_set = TokenizeDataSetLoader().load(data_path) | |||||
word_vocab = Vocabulary() | |||||
label_vocab = Vocabulary() | |||||
data_set.update_vocab(word_seq=word_vocab, label_seq=label_vocab) | |||||
data_set.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab) | |||||
data_set.set_origin_len("word_seq") | |||||
data_set.rename_field("label_seq", "truth").set_target(truth=False) | |||||
data_train, data_dev = data_set.split(0.3, shuffle=True) | data_train, data_dev = data_set.split(0.3, shuffle=True) | ||||
model_args["vocab_size"] = len(data_set.word_vocab) | |||||
model_args["num_classes"] = len(data_set.label_vocab) | |||||
model_args["vocab_size"] = len(word_vocab) | |||||
model_args["num_classes"] = len(label_vocab) | |||||
save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl") | |||||
save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl") | |||||
save_pickle(word_vocab, pickle_path, "word2id.pkl") | |||||
save_pickle(label_vocab, pickle_path, "label2id.pkl") | |||||
trainer = SeqLabelTrainer( | trainer = SeqLabelTrainer( | ||||
epochs=trainer_args["epochs"], | epochs=trainer_args["epochs"], | ||||