@@ -69,6 +69,6 @@ class Batch(object): | |||||
else: | else: | ||||
batch[name] = torch.stack(tensor_list, dim=0) | batch[name] = torch.stack(tensor_list, dim=0) | ||||
self.curidx += endidx | |||||
self.curidx = endidx | |||||
return batch_x, batch_y | return batch_x, batch_y | ||||
@@ -144,6 +144,15 @@ class DataSet(list): | |||||
else: | else: | ||||
self.convert(raw_data) | self.convert(raw_data) | ||||
def load_raw(self, raw_data, vocabs): | |||||
""" | |||||
:param raw_data: | |||||
:param vocabs: | |||||
:return: | |||||
""" | |||||
self.convert_for_infer(raw_data, vocabs) | |||||
def split(self, ratio, shuffle=True): | def split(self, ratio, shuffle=True): | ||||
"""Train/dev splitting | """Train/dev splitting | ||||
@@ -38,14 +38,19 @@ class SeqLabelEvaluator(Evaluator): | |||||
def __call__(self, predict, truth): | def __call__(self, predict, truth): | ||||
""" | """ | ||||
:param predict: list of tensors, the network outputs from all batches. | |||||
:param predict: list of List, the network outputs from all batches. | |||||
:param truth: list of dict, the ground truths from all batch_y. | :param truth: list of dict, the ground truths from all batch_y. | ||||
:return accuracy: | :return accuracy: | ||||
""" | """ | ||||
truth = [item["truth"] for item in truth] | truth = [item["truth"] for item in truth] | ||||
truth = torch.cat(truth).view(-1, ) | |||||
results = torch.Tensor(predict).view(-1, ) | |||||
accuracy = torch.sum(results.to(truth) == truth).to(torch.float) / results.shape[0] | |||||
total_correct, total_count= 0., 0. | |||||
for x, y in zip(predict, truth): | |||||
mask = torch.Tensor(x).ge(1) | |||||
correct = torch.sum(torch.Tensor(x) * mask.float() == (y * mask.long()).float()) | |||||
correct -= torch.sum(torch.Tensor(x).le(0)) | |||||
total_correct += float(correct) | |||||
total_count += float(torch.sum(mask)) | |||||
accuracy = total_correct / total_count | |||||
return {"accuracy": float(accuracy)} | return {"accuracy": float(accuracy)} | ||||
@@ -34,7 +34,7 @@ class Predictor(object): | |||||
"""Perform inference using the trained model. | """Perform inference using the trained model. | ||||
:param network: a PyTorch model (cpu) | :param network: a PyTorch model (cpu) | ||||
:param data: list of list of strings, [num_examples, seq_len] | |||||
:param data: a DataSet object. | |||||
:return: list of list of strings, [num_examples, tag_seq_length] | :return: list of list of strings, [num_examples, tag_seq_length] | ||||
""" | """ | ||||
# transform strings into DataSet object | # transform strings into DataSet object | ||||
@@ -18,6 +18,9 @@ def save_pickle(obj, pickle_path, file_name): | |||||
:param pickle_path: str, the directory where the pickle file is to be saved | :param pickle_path: str, the directory where the pickle file is to be saved | ||||
:param file_name: str, the name of the pickle file. In general, it should be ended by "pkl". | :param file_name: str, the name of the pickle file. In general, it should be ended by "pkl". | ||||
""" | """ | ||||
if not os.path.exists(pickle_path): | |||||
os.mkdir(pickle_path) | |||||
print("make dir {} before saving pickle file".format(pickle_path)) | |||||
with open(os.path.join(pickle_path, file_name), "wb") as f: | with open(os.path.join(pickle_path, file_name), "wb") as f: | ||||
_pickle.dump(obj, f) | _pickle.dump(obj, f) | ||||
print("{} saved in {}".format(file_name, pickle_path)) | print("{} saved in {}".format(file_name, pickle_path)) | ||||
@@ -4,6 +4,8 @@ from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer | |||||
from fastNLP.core.preprocess import load_pickle | from fastNLP.core.preprocess import load_pickle | ||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet | |||||
""" | """ | ||||
mapping from model name to [URL, file_name.class_name, model_pickle_name] | mapping from model name to [URL, file_name.class_name, model_pickle_name] | ||||
@@ -76,6 +78,8 @@ class FastNLP(object): | |||||
self.model_dir = model_dir | self.model_dir = model_dir | ||||
self.model = None | self.model = None | ||||
self.infer_type = None # "seq_label"/"text_class" | self.infer_type = None # "seq_label"/"text_class" | ||||
self.word_vocab = None | |||||
self.label_vocab = None | |||||
def load(self, model_name, config_file="config", section_name="model"): | def load(self, model_name, config_file="config", section_name="model"): | ||||
""" | """ | ||||
@@ -100,10 +104,10 @@ class FastNLP(object): | |||||
print("Restore model hyper-parameters {}".format(str(model_args.data))) | print("Restore model hyper-parameters {}".format(str(model_args.data))) | ||||
# fetch dictionary size and number of labels from pickle files | # fetch dictionary size and number of labels from pickle files | ||||
word_vocab = load_pickle(self.model_dir, "word2id.pkl") | |||||
model_args["vocab_size"] = len(word_vocab) | |||||
label_vocab = load_pickle(self.model_dir, "class2id.pkl") | |||||
model_args["num_classes"] = len(label_vocab) | |||||
self.word_vocab = load_pickle(self.model_dir, "word2id.pkl") | |||||
model_args["vocab_size"] = len(self.word_vocab) | |||||
self.label_vocab = load_pickle(self.model_dir, "label2id.pkl") | |||||
model_args["num_classes"] = len(self.label_vocab) | |||||
# Construct the model | # Construct the model | ||||
model = model_class(model_args) | model = model_class(model_args) | ||||
@@ -130,8 +134,11 @@ class FastNLP(object): | |||||
# tokenize: list of string ---> 2-D list of string | # tokenize: list of string ---> 2-D list of string | ||||
infer_input = self.tokenize(raw_input, language="zh") | infer_input = self.tokenize(raw_input, language="zh") | ||||
# 2-D list of string ---> 2-D list of tags | |||||
results = infer.predict(self.model, infer_input) | |||||
# create DataSet: 2-D list of strings ----> DataSet | |||||
infer_data = self._create_data_set(infer_input) | |||||
# DataSet ---> 2-D list of tags | |||||
results = infer.predict(self.model, infer_data) | |||||
# 2-D list of tags ---> list of final answers | # 2-D list of tags ---> list of final answers | ||||
outputs = self._make_output(results, infer_input) | outputs = self._make_output(results, infer_input) | ||||
@@ -154,6 +161,11 @@ class FastNLP(object): | |||||
return module | return module | ||||
def _create_inference(self, model_dir): | def _create_inference(self, model_dir): | ||||
"""Specify which task to perform. | |||||
:param model_dir: | |||||
:return: | |||||
""" | |||||
if self.infer_type == "seq_label": | if self.infer_type == "seq_label": | ||||
return SeqLabelInfer(model_dir) | return SeqLabelInfer(model_dir) | ||||
elif self.infer_type == "text_class": | elif self.infer_type == "text_class": | ||||
@@ -161,6 +173,24 @@ class FastNLP(object): | |||||
else: | else: | ||||
raise ValueError("fail to create inference instance") | raise ValueError("fail to create inference instance") | ||||
def _create_data_set(self, infer_input): | |||||
"""Create a DataSet object given the raw inputs. | |||||
:param infer_input: 2-D lists of strings | |||||
:return data_set: a DataSet object | |||||
""" | |||||
if self.infer_type == "seq_label": | |||||
data_set = SeqLabelDataSet() | |||||
data_set.load_raw(infer_input, {"word_vocab": self.word_vocab}) | |||||
return data_set | |||||
elif self.infer_type == "text_class": | |||||
data_set = TextClassifyDataSet() | |||||
data_set.load_raw(infer_input, {"word_vocab": self.word_vocab}) | |||||
return data_set | |||||
else: | |||||
raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type)) | |||||
def _load(self, model_dir, model_name): | def _load(self, model_dir, model_name): | ||||
# To do | # To do | ||||
return 0 | return 0 | ||||
@@ -18,7 +18,7 @@ class ConfigSaver(object): | |||||
:return: The section. | :return: The section. | ||||
""" | """ | ||||
sect = ConfigSection() | sect = ConfigSection() | ||||
ConfigLoader(self.file_path).load_config(self.file_path, {sect_name: sect}) | |||||
ConfigLoader().load_config(self.file_path, {sect_name: sect}) | |||||
return sect | return sect | ||||
def _read_section(self): | def _read_section(self): | ||||
@@ -43,8 +43,10 @@ class TestCase1(unittest.TestCase): | |||||
# use batch to iterate dataset | # use batch to iterate dataset | ||||
data_iterator = Batch(data, 2, SeqSampler(), False) | data_iterator = Batch(data, 2, SeqSampler(), False) | ||||
total_data = 0 | |||||
for batch_x, batch_y in data_iterator: | for batch_x, batch_y in data_iterator: | ||||
self.assertEqual(len(batch_x), 2) | |||||
total_data += batch_x["text"].size(0) | |||||
self.assertTrue(batch_x["text"].size(0) == 2 or total_data == len(raw_texts)) | |||||
self.assertTrue(isinstance(batch_x, dict)) | self.assertTrue(isinstance(batch_x, dict)) | ||||
self.assertTrue(isinstance(batch_x["text"], torch.LongTensor)) | self.assertTrue(isinstance(batch_x["text"], torch.LongTensor)) | ||||
self.assertTrue(isinstance(batch_y, dict)) | self.assertTrue(isinstance(batch_y, dict)) | ||||
@@ -1,20 +1,42 @@ | |||||
import sys, os | |||||
import os | |||||
import sys | |||||
sys.path = [os.path.join(os.path.dirname(__file__), '..')] + sys.path | sys.path = [os.path.join(os.path.dirname(__file__), '..')] + sys.path | ||||
from fastNLP.core import metrics | from fastNLP.core import metrics | ||||
# from sklearn import metrics as skmetrics | # from sklearn import metrics as skmetrics | ||||
import unittest | import unittest | ||||
import numpy as np | |||||
from numpy import random | from numpy import random | ||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
import torch | |||||
def generate_fake_label(low, high, size): | def generate_fake_label(low, high, size): | ||||
return random.randint(low, high, size), random.randint(low, high, size) | return random.randint(low, high, size), random.randint(low, high, size) | ||||
class TestEvaluator(unittest.TestCase): | |||||
def test_a(self): | |||||
evaluator = SeqLabelEvaluator() | |||||
pred = [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]] | |||||
truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4])}] | |||||
ans = evaluator(pred, truth) | |||||
print(ans) | |||||
def test_b(self): | |||||
evaluator = SeqLabelEvaluator() | |||||
pred = [[1, 2, 3, 4, 5, 0, 0], [1, 2, 3, 4, 5, 0, 0]] | |||||
truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3, 0, 0])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4, 0, 0])}] | |||||
ans = evaluator(pred, truth) | |||||
print(ans) | |||||
class TestMetrics(unittest.TestCase): | class TestMetrics(unittest.TestCase): | ||||
delta = 1e-5 | delta = 1e-5 | ||||
# test for binary, multiclass, multilabel | # test for binary, multiclass, multilabel | ||||
data_types = [((1000,), 2), ((1000,), 10), ((1000, 10), 2)] | data_types = [((1000,), 2), ((1000,), 10), ((1000, 10), 2)] | ||||
fake_data = [generate_fake_label(0, high, shape) for shape, high in data_types] | fake_data = [generate_fake_label(0, high, shape) for shape, high in data_types] | ||||
def test_accuracy_score(self): | def test_accuracy_score(self): | ||||
for y_true, y_pred in self.fake_data: | for y_true, y_pred in self.fake_data: | ||||
for normalize in [True, False]: | for normalize in [True, False]: | ||||
@@ -22,7 +44,7 @@ class TestMetrics(unittest.TestCase): | |||||
test = metrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) | test = metrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) | ||||
# ans = skmetrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) | # ans = skmetrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) | ||||
# self.assertAlmostEqual(test, ans, delta=self.delta) | # self.assertAlmostEqual(test, ans, delta=self.delta) | ||||
def test_recall_score(self): | def test_recall_score(self): | ||||
for y_true, y_pred in self.fake_data: | for y_true, y_pred in self.fake_data: | ||||
# print(y_true.shape) | # print(y_true.shape) | ||||
@@ -73,5 +95,6 @@ class TestMetrics(unittest.TestCase): | |||||
# ans = skmetrics.f1_score(y_true, y_pred) | # ans = skmetrics.f1_score(y_true, y_pred) | ||||
# self.assertAlmostEqual(ans, test, delta=self.delta) | # self.assertAlmostEqual(ans, test, delta=self.delta) | ||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
unittest.main() | unittest.main() |
@@ -2,9 +2,12 @@ import os | |||||
import unittest | import unittest | ||||
from fastNLP.core.predictor import Predictor | from fastNLP.core.predictor import Predictor | ||||
from fastNLP.core.dataset import TextClassifyDataSet, SeqLabelDataSet | |||||
from fastNLP.core.preprocess import save_pickle | from fastNLP.core.preprocess import save_pickle | ||||
from fastNLP.models.sequence_modeling import SeqLabeling | |||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.loader.base_loader import BaseLoader | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | |||||
from fastNLP.models.cnn_text_classification import CNNText | |||||
class TestPredictor(unittest.TestCase): | class TestPredictor(unittest.TestCase): | ||||
@@ -28,23 +31,44 @@ class TestPredictor(unittest.TestCase): | |||||
vocab = Vocabulary() | vocab = Vocabulary() | ||||
vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | ||||
class_vocab = Vocabulary() | class_vocab = Vocabulary() | ||||
class_vocab.word2idx = {"0":0, "1":1, "2":2, "3":3, "4":4} | |||||
class_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4} | |||||
os.system("mkdir save") | os.system("mkdir save") | ||||
save_pickle(class_vocab, "./save/", "class2id.pkl") | |||||
save_pickle(class_vocab, "./save/", "label2id.pkl") | |||||
save_pickle(vocab, "./save/", "word2id.pkl") | save_pickle(vocab, "./save/", "word2id.pkl") | ||||
model = SeqLabeling(model_args) | |||||
predictor = Predictor("./save/", task="seq_label") | |||||
model = CNNText(model_args) | |||||
import fastNLP.core.predictor as pre | |||||
predictor = Predictor("./save/", pre.text_classify_post_processor) | |||||
results = predictor.predict(network=model, data=infer_data) | |||||
# Load infer data | |||||
infer_data_set = TextClassifyDataSet(loader=BaseLoader()) | |||||
infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx}) | |||||
results = predictor.predict(network=model, data=infer_data_set) | |||||
self.assertTrue(isinstance(results, list)) | self.assertTrue(isinstance(results, list)) | ||||
self.assertGreater(len(results), 0) | self.assertGreater(len(results), 0) | ||||
self.assertEqual(len(results), len(infer_data)) | |||||
for res in results: | for res in results: | ||||
self.assertTrue(isinstance(res, str)) | |||||
self.assertTrue(res in class_vocab.word2idx) | |||||
del model, predictor, infer_data_set | |||||
model = SeqLabeling(model_args) | |||||
predictor = Predictor("./save/", pre.seq_label_post_processor) | |||||
infer_data_set = SeqLabelDataSet(loader=BaseLoader()) | |||||
infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx}) | |||||
results = predictor.predict(network=model, data=infer_data_set) | |||||
self.assertTrue(isinstance(results, list)) | |||||
self.assertEqual(len(results), len(infer_data)) | |||||
for i in range(len(infer_data)): | |||||
res = results[i] | |||||
self.assertTrue(isinstance(res, list)) | self.assertTrue(isinstance(res, list)) | ||||
self.assertEqual(len(res), 5) | |||||
self.assertTrue(isinstance(res[0], str)) | |||||
self.assertEqual(len(res), len(infer_data[i])) | |||||
os.system("rm -rf save") | os.system("rm -rf save") | ||||
print("pickle path deleted") | print("pickle path deleted") | ||||
@@ -1,8 +1,9 @@ | |||||
import os | import os | ||||
import unittest | import unittest | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.field import TextField | |||||
from fastNLP.core.dataset import SeqLabelDataSet | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
from fastNLP.core.field import TextField, LabelField | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.tester import SeqLabelTester | from fastNLP.core.tester import SeqLabelTester | ||||
from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
@@ -21,7 +22,7 @@ class TestTester(unittest.TestCase): | |||||
} | } | ||||
valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | ||||
"save_loss": True, "batch_size": 2, "pickle_path": "./save/", | "save_loss": True, "batch_size": 2, "pickle_path": "./save/", | ||||
"use_cuda": False, "print_every_step": 1} | |||||
"use_cuda": False, "print_every_step": 1, "evaluator": SeqLabelEvaluator()} | |||||
train_data = [ | train_data = [ | ||||
[['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | ||||
@@ -34,16 +35,17 @@ class TestTester(unittest.TestCase): | |||||
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | ||||
label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | ||||
data_set = DataSet() | |||||
data_set = SeqLabelDataSet() | |||||
for example in train_data: | for example in train_data: | ||||
text, label = example[0], example[1] | text, label = example[0], example[1] | ||||
x = TextField(text, False) | x = TextField(text, False) | ||||
x_len = LabelField(len(text), is_target=False) | |||||
y = TextField(label, is_target=True) | y = TextField(label, is_target=True) | ||||
ins = Instance(word_seq=x, label_seq=y) | |||||
ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len) | |||||
data_set.append(ins) | data_set.append(ins) | ||||
data_set.index_field("word_seq", vocab) | data_set.index_field("word_seq", vocab) | ||||
data_set.index_field("label_seq", label_vocab) | |||||
data_set.index_field("truth", label_vocab) | |||||
model = SeqLabeling(model_args) | model = SeqLabeling(model_args) | ||||
@@ -1,8 +1,9 @@ | |||||
import os | import os | ||||
import unittest | import unittest | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.field import TextField | |||||
from fastNLP.core.dataset import SeqLabelDataSet | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
from fastNLP.core.field import TextField, LabelField | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.loss import Loss | from fastNLP.core.loss import Loss | ||||
from fastNLP.core.optimizer import Optimizer | from fastNLP.core.optimizer import Optimizer | ||||
@@ -12,14 +13,15 @@ from fastNLP.models.sequence_modeling import SeqLabeling | |||||
class TestTrainer(unittest.TestCase): | class TestTrainer(unittest.TestCase): | ||||
def test_case_1(self): | def test_case_1(self): | ||||
args = {"epochs": 3, "batch_size": 2, "validate": True, "use_cuda": False, "pickle_path": "./save/", | |||||
args = {"epochs": 3, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | |||||
"save_best_dev": True, "model_name": "default_model_name.pkl", | "save_best_dev": True, "model_name": "default_model_name.pkl", | ||||
"loss": Loss(None), | |||||
"loss": Loss("cross_entropy"), | |||||
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | ||||
"vocab_size": 10, | "vocab_size": 10, | ||||
"word_emb_dim": 100, | "word_emb_dim": 100, | ||||
"rnn_hidden_units": 100, | "rnn_hidden_units": 100, | ||||
"num_classes": 5 | |||||
"num_classes": 5, | |||||
"evaluator": SeqLabelEvaluator() | |||||
} | } | ||||
trainer = SeqLabelTrainer(**args) | trainer = SeqLabelTrainer(**args) | ||||
@@ -34,16 +36,17 @@ class TestTrainer(unittest.TestCase): | |||||
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | ||||
label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | ||||
data_set = DataSet() | |||||
data_set = SeqLabelDataSet() | |||||
for example in train_data: | for example in train_data: | ||||
text, label = example[0], example[1] | text, label = example[0], example[1] | ||||
x = TextField(text, False) | x = TextField(text, False) | ||||
y = TextField(label, is_target=True) | |||||
ins = Instance(word_seq=x, label_seq=y) | |||||
x_len = LabelField(len(text), is_target=False) | |||||
y = TextField(label, is_target=False) | |||||
ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len) | |||||
data_set.append(ins) | data_set.append(ins) | ||||
data_set.index_field("word_seq", vocab) | data_set.index_field("word_seq", vocab) | ||||
data_set.index_field("label_seq", label_vocab) | |||||
data_set.index_field("truth", label_vocab) | |||||
model = SeqLabeling(args) | model = SeqLabeling(args) | ||||
@@ -9,10 +9,54 @@ input = [1,2,3] | |||||
text = "this is text" | text = "this is text" | ||||
doubles = 0.5 | |||||
doubles = 0.8 | |||||
tt = 0.5 | |||||
test = 105 | |||||
str = "this is a str" | |||||
double = 0.5 | |||||
[t] | [t] | ||||
x = "this is an test section" | x = "this is an test section" | ||||
[test-case-2] | [test-case-2] | ||||
double = 0.5 | double = 0.5 | ||||
doubles = 0.8 | |||||
tt = 0.5 | |||||
test = 105 | |||||
str = "this is a str" | |||||
[another-test] | |||||
doubles = 0.8 | |||||
tt = 0.5 | |||||
test = 105 | |||||
str = "this is a str" | |||||
double = 0.5 | |||||
[one-another-test] | |||||
doubles = 0.8 | |||||
tt = 0.5 | |||||
test = 105 | |||||
str = "this is a str" | |||||
double = 0.5 | |||||
@@ -31,7 +31,7 @@ class TestConfigLoader(unittest.TestCase): | |||||
return dict | return dict | ||||
test_arg = ConfigSection() | test_arg = ConfigSection() | ||||
ConfigLoader("config").load_config(os.path.join("./test/loader", "config"), {"test": test_arg}) | |||||
ConfigLoader().load_config(os.path.join("./test/loader", "config"), {"test": test_arg}) | |||||
section = read_section_from_config(os.path.join("./test/loader", "config"), "test") | section = read_section_from_config(os.path.join("./test/loader", "config"), "test") | ||||
@@ -1,3 +1,4 @@ | |||||
import os | |||||
import unittest | import unittest | ||||
from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \ | from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \ | ||||
@@ -14,28 +15,28 @@ class TestDatasetLoader(unittest.TestCase): | |||||
def test_case_TokenizeDatasetLoader(self): | def test_case_TokenizeDatasetLoader(self): | ||||
loader = TokenizeDataSetLoader() | loader = TokenizeDataSetLoader() | ||||
data = loader.load("test/data_for_tests/", max_seq_len=32) | |||||
data = loader.load("./test/data_for_tests/cws_pku_utf_8", max_seq_len=32) | |||||
print("pass TokenizeDataSetLoader test!") | print("pass TokenizeDataSetLoader test!") | ||||
def test_case_POSDatasetLoader(self): | def test_case_POSDatasetLoader(self): | ||||
loader = POSDataSetLoader() | loader = POSDataSetLoader() | ||||
data = loader.load() | |||||
datas = loader.load_lines() | |||||
data = loader.load("./test/data_for_tests/people.txt") | |||||
datas = loader.load_lines("./test/data_for_tests/people.txt") | |||||
print("pass POSDataSetLoader test!") | print("pass POSDataSetLoader test!") | ||||
def test_case_LMDatasetLoader(self): | def test_case_LMDatasetLoader(self): | ||||
loader = LMDataSetLoader() | loader = LMDataSetLoader() | ||||
data = loader.load() | |||||
datas = loader.load_lines() | |||||
data = loader.load("./test/data_for_tests/charlm.txt") | |||||
datas = loader.load_lines("./test/data_for_tests/charlm.txt") | |||||
print("pass TokenizeDataSetLoader test!") | print("pass TokenizeDataSetLoader test!") | ||||
def test_PeopleDailyCorpusLoader(self): | def test_PeopleDailyCorpusLoader(self): | ||||
loader = PeopleDailyCorpusLoader() | loader = PeopleDailyCorpusLoader() | ||||
_, _ = loader.load() | |||||
_, _ = loader.load("./test/data_for_tests/people_daily_raw.txt") | |||||
def test_ConllLoader(self): | def test_ConllLoader(self): | ||||
loader = ConllLoader("./test/data_for_tests/conll_example.txt") | |||||
_ = loader.load() | |||||
loader = ConllLoader() | |||||
_ = loader.load("./test/data_for_tests/conll_example.txt") | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
@@ -13,10 +13,10 @@ from fastNLP.models.sequence_modeling import SeqLabeling | |||||
from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
data_name = "pku_training.utf8" | data_name = "pku_training.utf8" | ||||
cws_data_path = "test/data_for_tests/cws_pku_utf_8" | |||||
cws_data_path = "./test/data_for_tests/cws_pku_utf_8" | |||||
pickle_path = "./save/" | pickle_path = "./save/" | ||||
data_infer_path = "test/data_for_tests/people_infer.txt" | |||||
config_path = "test/data_for_tests/config" | |||||
data_infer_path = "./test/data_for_tests/people_infer.txt" | |||||
config_path = "./test/data_for_tests/config" | |||||
def infer(): | def infer(): | ||||
# Load infer configuration, the same as test | # Load infer configuration, the same as test | ||||
@@ -21,7 +21,7 @@ class TestConfigSaver(unittest.TestCase): | |||||
standard_section = ConfigSection() | standard_section = ConfigSection() | ||||
t_section = ConfigSection() | t_section = ConfigSection() | ||||
ConfigLoader(config_file_path).load_config(config_file_path, {"test": standard_section, "t": t_section}) | |||||
ConfigLoader().load_config(config_file_path, {"test": standard_section, "t": t_section}) | |||||
config_saver = ConfigSaver(config_file_path) | config_saver = ConfigSaver(config_file_path) | ||||
@@ -48,11 +48,11 @@ class TestConfigSaver(unittest.TestCase): | |||||
one_another_test_section = ConfigSection() | one_another_test_section = ConfigSection() | ||||
a_test_case_2_section = ConfigSection() | a_test_case_2_section = ConfigSection() | ||||
ConfigLoader(config_file_path).load_config(config_file_path, {"test": test_section, | |||||
"another-test": another_test_section, | |||||
"t": at_section, | |||||
"one-another-test": one_another_test_section, | |||||
"test-case-2": a_test_case_2_section}) | |||||
ConfigLoader().load_config(config_file_path, {"test": test_section, | |||||
"another-test": another_test_section, | |||||
"t": at_section, | |||||
"one-another-test": one_another_test_section, | |||||
"test-case-2": a_test_case_2_section}) | |||||
assert test_section == standard_section | assert test_section == standard_section | ||||
assert at_section == t_section | assert at_section == t_section | ||||
@@ -54,7 +54,7 @@ def mock_cws(): | |||||
class2id = Vocabulary(need_default=False) | class2id = Vocabulary(need_default=False) | ||||
label_list = ['B', 'M', 'E', 'S'] | label_list = ['B', 'M', 'E', 'S'] | ||||
class2id.update(label_list) | class2id.update(label_list) | ||||
save_pickle(class2id, "./mock/", "class2id.pkl") | |||||
save_pickle(class2id, "./mock/", "label2id.pkl") | |||||
model_args = {"vocab_size": len(word2id), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(class2id)} | model_args = {"vocab_size": len(word2id), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(class2id)} | ||||
config_file = """ | config_file = """ | ||||
@@ -115,7 +115,7 @@ def mock_pos_tag(): | |||||
idx2label = Vocabulary(need_default=False) | idx2label = Vocabulary(need_default=False) | ||||
label_list = ['B-n', 'M-v', 'E-nv', 'S-adj', 'B-v', 'M-vn', 'S-adv'] | label_list = ['B-n', 'M-v', 'E-nv', 'S-adj', 'B-v', 'M-vn', 'S-adv'] | ||||
idx2label.update(label_list) | idx2label.update(label_list) | ||||
save_pickle(idx2label, "./mock/", "class2id.pkl") | |||||
save_pickle(idx2label, "./mock/", "label2id.pkl") | |||||
model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)} | model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)} | ||||
config_file = """ | config_file = """ | ||||
@@ -163,7 +163,7 @@ def mock_text_classify(): | |||||
idx2label = Vocabulary(need_default=False) | idx2label = Vocabulary(need_default=False) | ||||
label_list = ['class_A', 'class_B', 'class_C', 'class_D', 'class_E', 'class_F'] | label_list = ['class_A', 'class_B', 'class_C', 'class_D', 'class_E', 'class_F'] | ||||
idx2label.update(label_list) | idx2label.update(label_list) | ||||
save_pickle(idx2label, "./mock/", "class2id.pkl") | |||||
save_pickle(idx2label, "./mock/", "label2id.pkl") | |||||
model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)} | model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)} | ||||
config_file = """ | config_file = """ | ||||