* refine code style * set up unit tests for Batch, DataSet, FieldArray * remove a lot of out-of-date unit tests, to get testing passed (tags/v0.2.0)
@@ -64,6 +64,7 @@ class DataSet(object): | |||||
""" | """ | ||||
:param data: a dict or a list. If it is a dict, the key is the name of a field and the value is the field. | :param data: a dict or a list. If it is a dict, the key is the name of a field and the value is the field. | ||||
All values must be of the same length. | |||||
If it is a list, it must be a list of Instance objects. | If it is a list, it must be a list of Instance objects. | ||||
""" | """ | ||||
self.field_arrays = {} | self.field_arrays = {} | ||||
@@ -23,8 +23,7 @@ class FieldArray(object): | |||||
self.dtype = None | self.dtype = None | ||||
def __repr__(self): | def __repr__(self): | ||||
# TODO | |||||
return '{}: {}'.format(self.name, self.content.__repr__()) | |||||
return "FieldArray {}: {}".format(self.name, self.content.__repr__()) | |||||
def append(self, val): | def append(self, val): | ||||
self.content.append(val) | self.content.append(val) | ||||
@@ -11,7 +11,7 @@ class Instance(object): | |||||
def __init__(self, **fields): | def __init__(self, **fields): | ||||
""" | """ | ||||
:param fields: a dict of (field name: field) | |||||
:param fields: a dict of (str: list). | |||||
""" | """ | ||||
self.fields = fields | self.fields = fields | ||||
@@ -1,5 +1,6 @@ | |||||
import os | |||||
import _pickle as pickle | import _pickle as pickle | ||||
import os | |||||
class BaseLoader(object): | class BaseLoader(object): | ||||
@@ -1,7 +1,6 @@ | |||||
import os | import os | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.field import * | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.io.base_loader import BaseLoader | from fastNLP.io.base_loader import BaseLoader | ||||
@@ -87,6 +86,7 @@ class DataSetLoader(BaseLoader): | |||||
""" | """ | ||||
raise NotImplementedError | raise NotImplementedError | ||||
@DataSet.set_reader('read_raw') | @DataSet.set_reader('read_raw') | ||||
class RawDataSetLoader(DataSetLoader): | class RawDataSetLoader(DataSetLoader): | ||||
def __init__(self): | def __init__(self): | ||||
@@ -102,6 +102,7 @@ class RawDataSetLoader(DataSetLoader): | |||||
def convert(self, data): | def convert(self, data): | ||||
return convert_seq_dataset(data) | return convert_seq_dataset(data) | ||||
@DataSet.set_reader('read_pos') | @DataSet.set_reader('read_pos') | ||||
class POSDataSetLoader(DataSetLoader): | class POSDataSetLoader(DataSetLoader): | ||||
"""Dataset Loader for POS Tag datasets. | """Dataset Loader for POS Tag datasets. | ||||
@@ -171,6 +172,7 @@ class POSDataSetLoader(DataSetLoader): | |||||
""" | """ | ||||
return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
@DataSet.set_reader('read_tokenize') | @DataSet.set_reader('read_tokenize') | ||||
class TokenizeDataSetLoader(DataSetLoader): | class TokenizeDataSetLoader(DataSetLoader): | ||||
""" | """ | ||||
@@ -230,6 +232,7 @@ class TokenizeDataSetLoader(DataSetLoader): | |||||
def convert(self, data): | def convert(self, data): | ||||
return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
@DataSet.set_reader('read_class') | @DataSet.set_reader('read_class') | ||||
class ClassDataSetLoader(DataSetLoader): | class ClassDataSetLoader(DataSetLoader): | ||||
"""Loader for classification data sets""" | """Loader for classification data sets""" | ||||
@@ -268,6 +271,7 @@ class ClassDataSetLoader(DataSetLoader): | |||||
def convert(self, data): | def convert(self, data): | ||||
return convert_seq2tag_dataset(data) | return convert_seq2tag_dataset(data) | ||||
@DataSet.set_reader('read_conll') | @DataSet.set_reader('read_conll') | ||||
class ConllLoader(DataSetLoader): | class ConllLoader(DataSetLoader): | ||||
"""loader for conll format files""" | """loader for conll format files""" | ||||
@@ -309,6 +313,7 @@ class ConllLoader(DataSetLoader): | |||||
def convert(self, data): | def convert(self, data): | ||||
pass | pass | ||||
@DataSet.set_reader('read_lm') | @DataSet.set_reader('read_lm') | ||||
class LMDataSetLoader(DataSetLoader): | class LMDataSetLoader(DataSetLoader): | ||||
"""Language Model Dataset Loader | """Language Model Dataset Loader | ||||
@@ -345,6 +350,7 @@ class LMDataSetLoader(DataSetLoader): | |||||
def convert(self, data): | def convert(self, data): | ||||
pass | pass | ||||
@DataSet.set_reader('read_people_daily') | @DataSet.set_reader('read_people_daily') | ||||
class PeopleDailyCorpusLoader(DataSetLoader): | class PeopleDailyCorpusLoader(DataSetLoader): | ||||
""" | """ | ||||
@@ -1,6 +1,9 @@ | |||||
import unittest | import unittest | ||||
import numpy as np | |||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.dataset import construct_dataset | from fastNLP.core.dataset import construct_dataset | ||||
from fastNLP.core.sampler import SequentialSampler | from fastNLP.core.sampler import SequentialSampler | ||||
@@ -10,9 +13,21 @@ class TestCase1(unittest.TestCase): | |||||
dataset = construct_dataset( | dataset = construct_dataset( | ||||
[["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)]) | [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)]) | ||||
dataset.set_target() | dataset.set_target() | ||||
batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), use_cuda=False) | |||||
batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
cnt = 0 | cnt = 0 | ||||
for _, _ in batch: | for _, _ in batch: | ||||
cnt += 1 | cnt += 1 | ||||
self.assertEqual(cnt, 10) | self.assertEqual(cnt, 10) | ||||
def test_dataset_batching(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
ds.set_input(x=True) | |||||
ds.set_target(y=True) | |||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
for x, y in iter: | |||||
self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray)) | |||||
self.assertEqual(len(x["x"]), 4) | |||||
self.assertEqual(len(y["y"]), 4) | |||||
self.assertListEqual(list(x["x"][-1]), [1, 2, 3, 4]) | |||||
self.assertListEqual(list(y["y"][-1]), [5, 6]) |
@@ -1,20 +1,75 @@ | |||||
import unittest | import unittest | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.instance import Instance | |||||
class TestDataSet(unittest.TestCase): | class TestDataSet(unittest.TestCase): | ||||
def test_case_1(self): | |||||
ds = DataSet() | |||||
ds.add_field(name="xx", fields=["a", "b", "e", "d"]) | |||||
def test_init_v1(self): | |||||
ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40) | |||||
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | |||||
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | |||||
self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) | |||||
self.assertTrue("xx" in ds.field_arrays) | |||||
self.assertEqual(len(ds.field_arrays["xx"]), 4) | |||||
self.assertEqual(ds.get_length(), 4) | |||||
self.assertEqual(ds.get_fields(), ds.field_arrays) | |||||
def test_init_v2(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | |||||
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | |||||
self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) | |||||
try: | |||||
ds.add_field(name="yy", fields=["x", "y", "z", "w", "f"]) | |||||
except BaseException as e: | |||||
self.assertTrue(isinstance(e, AssertionError)) | |||||
def test_init_assert(self): | |||||
with self.assertRaises(AssertionError): | |||||
_ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100}) | |||||
with self.assertRaises(AssertionError): | |||||
_ = DataSet([[1, 2, 3, 4]] * 10) | |||||
with self.assertRaises(ValueError): | |||||
_ = DataSet(0.00001) | |||||
def test_append(self): | |||||
dd = DataSet() | |||||
for _ in range(3): | |||||
dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6])) | |||||
self.assertEqual(len(dd), 3) | |||||
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3) | |||||
self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3) | |||||
def test_add_append(self): | |||||
dd = DataSet() | |||||
dd.add_field("x", [[1, 2, 3]] * 10) | |||||
dd.add_field("y", [[1, 2, 3, 4]] * 10) | |||||
dd.add_field("z", [[5, 6]] * 10) | |||||
self.assertEqual(len(dd), 10) | |||||
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3]] * 10) | |||||
self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10) | |||||
self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10) | |||||
def test_delete_field(self): | |||||
dd = DataSet() | |||||
dd.add_field("x", [[1, 2, 3]] * 10) | |||||
dd.add_field("y", [[1, 2, 3, 4]] * 10) | |||||
dd.delete_field("x") | |||||
self.assertFalse("x" in dd.field_arrays) | |||||
self.assertTrue("y" in dd.field_arrays) | |||||
def test_getitem(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
ins_1, ins_0 = ds[0], ds[1] | |||||
self.assertTrue(isinstance(ins_1, DataSet.Instance) and isinstance(ins_0, DataSet.Instance)) | |||||
self.assertEqual(ins_1["x"], [1, 2, 3, 4]) | |||||
self.assertEqual(ins_1["y"], [5, 6]) | |||||
self.assertEqual(ins_0["x"], [1, 2, 3, 4]) | |||||
self.assertEqual(ins_0["y"], [5, 6]) | |||||
sub_ds = ds[:10] | |||||
self.assertTrue(isinstance(sub_ds, DataSet)) | |||||
self.assertEqual(len(sub_ds), 10) | |||||
field = ds["x"] | |||||
self.assertEqual(field, ds.field_arrays["x"]) | |||||
def test_apply(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx") | |||||
self.assertTrue("rx" in ds.field_arrays) | |||||
self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1]) |
@@ -1,6 +1,22 @@ | |||||
import unittest | import unittest | ||||
import numpy as np | |||||
from fastNLP.core.fieldarray import FieldArray | |||||
class TestFieldArray(unittest.TestCase): | class TestFieldArray(unittest.TestCase): | ||||
def test(self): | def test(self): | ||||
pass | |||||
fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) | |||||
self.assertEqual(len(fa), 5) | |||||
fa.append(6) | |||||
self.assertEqual(len(fa), 6) | |||||
self.assertEqual(fa[-1], 6) | |||||
self.assertEqual(fa[0], 1) | |||||
fa[-1] = 60 | |||||
self.assertEqual(fa[-1], 60) | |||||
self.assertEqual(fa.get(0), 1) | |||||
self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray)) | |||||
self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3]) |
@@ -1,100 +0,0 @@ | |||||
import os | |||||
import sys | |||||
sys.path = [os.path.join(os.path.dirname(__file__), '..')] + sys.path | |||||
from fastNLP.core import metrics | |||||
# from sklearn import metrics as skmetrics | |||||
import unittest | |||||
from numpy import random | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
import torch | |||||
def generate_fake_label(low, high, size): | |||||
return random.randint(low, high, size), random.randint(low, high, size) | |||||
class TestEvaluator(unittest.TestCase): | |||||
def test_a(self): | |||||
evaluator = SeqLabelEvaluator() | |||||
pred = [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]] | |||||
truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4])}] | |||||
ans = evaluator(pred, truth) | |||||
print(ans) | |||||
def test_b(self): | |||||
evaluator = SeqLabelEvaluator() | |||||
pred = [[1, 2, 3, 4, 5, 0, 0], [1, 2, 3, 4, 5, 0, 0]] | |||||
truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3, 0, 0])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4, 0, 0])}] | |||||
ans = evaluator(pred, truth) | |||||
print(ans) | |||||
class TestMetrics(unittest.TestCase): | |||||
delta = 1e-5 | |||||
# test for binary, multiclass, multilabel | |||||
data_types = [((1000,), 2), ((1000,), 10), ((1000, 10), 2)] | |||||
fake_data = [generate_fake_label(0, high, shape) for shape, high in data_types] | |||||
def test_accuracy_score(self): | |||||
for y_true, y_pred in self.fake_data: | |||||
for normalize in [True, False]: | |||||
for sample_weight in [None, random.rand(y_true.shape[0])]: | |||||
test = metrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) | |||||
# ans = skmetrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) | |||||
# self.assertAlmostEqual(test, ans, delta=self.delta) | |||||
def test_recall_score(self): | |||||
for y_true, y_pred in self.fake_data: | |||||
# print(y_true.shape) | |||||
labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None | |||||
test = metrics.recall_score(y_true, y_pred, labels=labels, average=None) | |||||
if not isinstance(test, list): | |||||
test = list(test) | |||||
# ans = skmetrics.recall_score(y_true, y_pred,labels=labels, average=None) | |||||
# ans = list(ans) | |||||
# for a, b in zip(test, ans): | |||||
# # print('{}, {}'.format(a, b)) | |||||
# self.assertAlmostEqual(a, b, delta=self.delta) | |||||
# test binary | |||||
y_true, y_pred = generate_fake_label(0, 2, 1000) | |||||
test = metrics.recall_score(y_true, y_pred) | |||||
# ans = skmetrics.recall_score(y_true, y_pred) | |||||
# self.assertAlmostEqual(ans, test, delta=self.delta) | |||||
def test_precision_score(self): | |||||
for y_true, y_pred in self.fake_data: | |||||
# print(y_true.shape) | |||||
labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None | |||||
test = metrics.precision_score(y_true, y_pred, labels=labels, average=None) | |||||
# ans = skmetrics.precision_score(y_true, y_pred,labels=labels, average=None) | |||||
# ans, test = list(ans), list(test) | |||||
# for a, b in zip(test, ans): | |||||
# # print('{}, {}'.format(a, b)) | |||||
# self.assertAlmostEqual(a, b, delta=self.delta) | |||||
# test binary | |||||
y_true, y_pred = generate_fake_label(0, 2, 1000) | |||||
test = metrics.precision_score(y_true, y_pred) | |||||
# ans = skmetrics.precision_score(y_true, y_pred) | |||||
# self.assertAlmostEqual(ans, test, delta=self.delta) | |||||
def test_f1_score(self): | |||||
for y_true, y_pred in self.fake_data: | |||||
# print(y_true.shape) | |||||
labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None | |||||
test = metrics.f1_score(y_true, y_pred, labels=labels, average=None) | |||||
# ans = skmetrics.f1_score(y_true, y_pred,labels=labels, average=None) | |||||
# ans, test = list(ans), list(test) | |||||
# for a, b in zip(test, ans): | |||||
# # print('{}, {}'.format(a, b)) | |||||
# self.assertAlmostEqual(a, b, delta=self.delta) | |||||
# test binary | |||||
y_true, y_pred = generate_fake_label(0, 2, 1000) | |||||
test = metrics.f1_score(y_true, y_pred) | |||||
# ans = skmetrics.f1_score(y_true, y_pred) | |||||
# self.assertAlmostEqual(ans, test, delta=self.delta) | |||||
if __name__ == '__main__': | |||||
unittest.main() |
@@ -1,77 +1,6 @@ | |||||
import os | |||||
import unittest | import unittest | ||||
from fastNLP.core.predictor import Predictor | |||||
from fastNLP.core.utils import save_pickle | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.io.dataset_loader import convert_seq_dataset | |||||
from fastNLP.models.cnn_text_classification import CNNText | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | |||||
class TestPredictor(unittest.TestCase): | class TestPredictor(unittest.TestCase): | ||||
def test_seq_label(self): | |||||
model_args = { | |||||
"vocab_size": 10, | |||||
"word_emb_dim": 100, | |||||
"rnn_hidden_units": 100, | |||||
"num_classes": 5 | |||||
} | |||||
infer_data = [ | |||||
['a', 'b', 'c', 'd', 'e'], | |||||
['a', '@', 'c', 'd', 'e'], | |||||
['a', 'b', '#', 'd', 'e'], | |||||
['a', 'b', 'c', '?', 'e'], | |||||
['a', 'b', 'c', 'd', '$'], | |||||
['!', 'b', 'c', 'd', 'e'] | |||||
] | |||||
vocab = Vocabulary() | |||||
vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||||
class_vocab = Vocabulary() | |||||
class_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4} | |||||
os.system("mkdir save") | |||||
save_pickle(class_vocab, "./save/", "label2id.pkl") | |||||
save_pickle(vocab, "./save/", "word2id.pkl") | |||||
model = CNNText(model_args) | |||||
import fastNLP.core.predictor as pre | |||||
predictor = Predictor("./save/", pre.text_classify_post_processor) | |||||
# Load infer data | |||||
infer_data_set = convert_seq_dataset(infer_data) | |||||
infer_data_set.index_field("word_seq", vocab) | |||||
results = predictor.predict(network=model, data=infer_data_set) | |||||
self.assertTrue(isinstance(results, list)) | |||||
self.assertGreater(len(results), 0) | |||||
self.assertEqual(len(results), len(infer_data)) | |||||
for res in results: | |||||
self.assertTrue(isinstance(res, str)) | |||||
self.assertTrue(res in class_vocab.word2idx) | |||||
del model, predictor | |||||
infer_data_set.set_origin_len("word_seq") | |||||
model = SeqLabeling(model_args) | |||||
predictor = Predictor("./save/", pre.seq_label_post_processor) | |||||
results = predictor.predict(network=model, data=infer_data_set) | |||||
self.assertTrue(isinstance(results, list)) | |||||
self.assertEqual(len(results), len(infer_data)) | |||||
for i in range(len(infer_data)): | |||||
res = results[i] | |||||
self.assertTrue(isinstance(res, list)) | |||||
self.assertEqual(len(res), len(infer_data[i])) | |||||
os.system("rm -rf save") | |||||
print("pickle path deleted") | |||||
class TestPredictor2(unittest.TestCase): | |||||
def test_text_classify(self): | |||||
# TODO | |||||
def test(self): | |||||
pass | pass |
@@ -1,57 +1,9 @@ | |||||
import os | |||||
import unittest | import unittest | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.field import TextField, LabelField | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
from fastNLP.core.tester import Tester | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | |||||
data_name = "pku_training.utf8" | data_name = "pku_training.utf8" | ||||
pickle_path = "data_for_tests" | pickle_path = "data_for_tests" | ||||
class TestTester(unittest.TestCase): | class TestTester(unittest.TestCase): | ||||
def test_case_1(self): | def test_case_1(self): | ||||
model_args = { | |||||
"vocab_size": 10, | |||||
"word_emb_dim": 100, | |||||
"rnn_hidden_units": 100, | |||||
"num_classes": 5 | |||||
} | |||||
valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | |||||
"save_loss": True, "batch_size": 2, "pickle_path": "./save/", | |||||
"use_cuda": False, "print_every_step": 1, "evaluator": SeqLabelEvaluator()} | |||||
train_data = [ | |||||
[['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']], | |||||
[['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
] | |||||
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||||
label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | |||||
data_set = DataSet() | |||||
for example in train_data: | |||||
text, label = example[0], example[1] | |||||
x = TextField(text, False) | |||||
x_len = LabelField(len(text), is_target=False) | |||||
y = TextField(label, is_target=True) | |||||
ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len) | |||||
data_set.append(ins) | |||||
data_set.index_field("word_seq", vocab) | |||||
data_set.index_field("truth", label_vocab) | |||||
model = SeqLabeling(model_args) | |||||
tester = Tester(**valid_args) | |||||
tester.test(network=model, dev_data=data_set) | |||||
# If this can run, everything is OK. | |||||
os.system("rm -rf save") | |||||
print("pickle path deleted") | |||||
pass |
@@ -1,57 +1,6 @@ | |||||
import os | |||||
import unittest | import unittest | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.field import TextField, LabelField | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.core.loss import Loss | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
from fastNLP.core.optimizer import Optimizer | |||||
from fastNLP.core.trainer import Trainer | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | |||||
class TestTrainer(unittest.TestCase): | class TestTrainer(unittest.TestCase): | ||||
def test_case_1(self): | def test_case_1(self): | ||||
args = {"epochs": 3, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | |||||
"save_best_dev": True, "model_name": "default_model_name.pkl", | |||||
"loss": Loss("cross_entropy"), | |||||
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | |||||
"vocab_size": 10, | |||||
"word_emb_dim": 100, | |||||
"rnn_hidden_units": 100, | |||||
"num_classes": 5, | |||||
"evaluator": SeqLabelEvaluator() | |||||
} | |||||
trainer = Trainer(**args) | |||||
train_data = [ | |||||
[['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']], | |||||
[['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
] | |||||
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||||
label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | |||||
data_set = DataSet() | |||||
for example in train_data: | |||||
text, label = example[0], example[1] | |||||
x = TextField(text, False) | |||||
x_len = LabelField(len(text), is_target=False) | |||||
y = TextField(label, is_target=False) | |||||
ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len) | |||||
data_set.append(ins) | |||||
data_set.index_field("word_seq", vocab) | |||||
data_set.index_field("truth", label_vocab) | |||||
model = SeqLabeling(args) | |||||
trainer.train(network=model, train_data=data_set, dev_data=data_set) | |||||
# If this can run, everything is OK. | |||||
os.system("rm -rf save") | |||||
print("pickle path deleted") | |||||
pass |
@@ -1,53 +0,0 @@ | |||||
import configparser | |||||
import json | |||||
import os | |||||
import unittest | |||||
from fastNLP.io.config_loader import ConfigSection, ConfigLoader | |||||
class TestConfigLoader(unittest.TestCase): | |||||
def test_case_ConfigLoader(self): | |||||
def read_section_from_config(config_path, section_name): | |||||
dict = {} | |||||
if not os.path.exists(config_path): | |||||
raise FileNotFoundError("config file {} NOT found.".format(config_path)) | |||||
cfg = configparser.ConfigParser() | |||||
cfg.read(config_path) | |||||
if section_name not in cfg: | |||||
raise AttributeError("config file {} do NOT have section {}".format( | |||||
config_path, section_name | |||||
)) | |||||
gen_sec = cfg[section_name] | |||||
for s in gen_sec.keys(): | |||||
try: | |||||
val = json.loads(gen_sec[s]) | |||||
dict[s] = val | |||||
except Exception as e: | |||||
raise AttributeError("json can NOT load {} in section {}, config file {}".format( | |||||
s, section_name, config_path | |||||
)) | |||||
return dict | |||||
test_arg = ConfigSection() | |||||
ConfigLoader().load_config(os.path.join("./test/loader", "config"), {"test": test_arg}) | |||||
section = read_section_from_config(os.path.join("./test/loader", "config"), "test") | |||||
for sec in section: | |||||
if (sec not in test_arg) or (section[sec] != test_arg[sec]): | |||||
raise AttributeError("ERROR") | |||||
for sec in test_arg.__dict__.keys(): | |||||
if (sec not in section) or (section[sec] != test_arg[sec]): | |||||
raise AttributeError("ERROR") | |||||
try: | |||||
not_exist = test_arg["NOT EXIST"] | |||||
except Exception as e: | |||||
pass | |||||
print("pass config test!") | |||||
@@ -7,7 +7,7 @@ from fastNLP.io.config_saver import ConfigSaver | |||||
class TestConfigSaver(unittest.TestCase): | class TestConfigSaver(unittest.TestCase): | ||||
def test_case_1(self): | def test_case_1(self): | ||||
config_file_dir = "test/loader/" | |||||
config_file_dir = "test/io/" | |||||
config_file_name = "config" | config_file_name = "config" | ||||
config_file_path = os.path.join(config_file_dir, config_file_name) | config_file_path = os.path.join(config_file_dir, config_file_name) | ||||
@@ -1,53 +0,0 @@ | |||||
import unittest | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.io.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \ | |||||
PeopleDailyCorpusLoader, ConllLoader | |||||
class TestDatasetLoader(unittest.TestCase): | |||||
def test_case_1(self): | |||||
data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF""" | |||||
lines = data.split("\n") | |||||
answer = POSDataSetLoader.parse(lines) | |||||
truth = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], [["Hello", "world", "!"], ["T", "F", "F"]]] | |||||
self.assertListEqual(answer, truth, "POS Dataset Loader") | |||||
def test_case_TokenizeDatasetLoader(self): | |||||
loader = TokenizeDataSetLoader() | |||||
filepath = "./test/data_for_tests/cws_pku_utf_8" | |||||
data = loader.load(filepath, max_seq_len=32) | |||||
assert len(data) > 0 | |||||
data1 = DataSet() | |||||
data1.read_tokenize(filepath, max_seq_len=32) | |||||
assert len(data1) > 0 | |||||
print("pass TokenizeDataSetLoader test!") | |||||
def test_case_POSDatasetLoader(self): | |||||
loader = POSDataSetLoader() | |||||
filepath = "./test/data_for_tests/people.txt" | |||||
data = loader.load("./test/data_for_tests/people.txt") | |||||
datas = loader.load_lines("./test/data_for_tests/people.txt") | |||||
data1 = DataSet().read_pos(filepath) | |||||
assert len(data1) > 0 | |||||
print("pass POSDataSetLoader test!") | |||||
def test_case_LMDatasetLoader(self): | |||||
loader = LMDataSetLoader() | |||||
data = loader.load("./test/data_for_tests/charlm.txt") | |||||
datas = loader.load_lines("./test/data_for_tests/charlm.txt") | |||||
print("pass TokenizeDataSetLoader test!") | |||||
def test_PeopleDailyCorpusLoader(self): | |||||
loader = PeopleDailyCorpusLoader() | |||||
_, _ = loader.load("./test/data_for_tests/people_daily_raw.txt") | |||||
def test_ConllLoader(self): | |||||
loader = ConllLoader() | |||||
_ = loader.load("./test/data_for_tests/conll_example.txt") | |||||
if __name__ == '__main__': | |||||
unittest.main() |
@@ -1,31 +0,0 @@ | |||||
import os | |||||
import unittest | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.io.embed_loader import EmbedLoader | |||||
class TestEmbedLoader(unittest.TestCase): | |||||
glove_path = './test/data_for_tests/glove.6B.50d_test.txt' | |||||
pkl_path = './save' | |||||
raw_texts = ["i am a cat", | |||||
"this is a test of new batch", | |||||
"ha ha", | |||||
"I am a good boy .", | |||||
"This is the most beautiful girl ." | |||||
] | |||||
texts = [text.strip().split() for text in raw_texts] | |||||
vocab = Vocabulary() | |||||
vocab.update(texts) | |||||
def test1(self): | |||||
emb, _ = EmbedLoader.load_embedding(50, self.glove_path, 'glove', self.vocab, self.pkl_path) | |||||
self.assertTrue(emb.shape[0] == (len(self.vocab))) | |||||
self.assertTrue(emb.shape[1] == 50) | |||||
os.remove(self.pkl_path) | |||||
def test2(self): | |||||
try: | |||||
_ = EmbedLoader.load_embedding(100, self.glove_path, 'glove', self.vocab, self.pkl_path) | |||||
self.fail(msg="load dismatch embedding") | |||||
except ValueError: | |||||
pass |
@@ -1,150 +0,0 @@ | |||||
import os | |||||
import sys | |||||
sys.path.append("..") | |||||
import argparse | |||||
from fastNLP.io.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.io.dataset_loader import BaseLoader | |||||
from fastNLP.io.model_saver import ModelSaver | |||||
from fastNLP.io.model_loader import ModelLoader | |||||
from fastNLP.core.tester import SeqLabelTester | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | |||||
from fastNLP.core.predictor import SeqLabelInfer | |||||
from fastNLP.core.optimizer import Optimizer | |||||
from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
from fastNLP.core.utils import save_pickle, load_pickle | |||||
parser = argparse.ArgumentParser() | |||||
parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files") | |||||
parser.add_argument("-t", "--train", type=str, default="../data_for_tests/people.txt", | |||||
help="path to the training data") | |||||
parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file") | |||||
parser.add_argument("-m", "--model_name", type=str, default="seq_label_model.pkl", help="the name of the model") | |||||
parser.add_argument("-i", "--infer", type=str, default="../data_for_tests/people_infer.txt", | |||||
help="data used for inference") | |||||
args = parser.parse_args() | |||||
pickle_path = args.save | |||||
model_name = args.model_name | |||||
config_dir = args.config | |||||
data_path = args.train | |||||
data_infer_path = args.infer | |||||
def infer(): | |||||
# Load infer configuration, the same as test | |||||
test_args = ConfigSection() | |||||
ConfigLoader().load_config(config_dir, {"POS_infer": test_args}) | |||||
# fetch dictionary size and number of labels from pickle files | |||||
word_vocab = load_pickle(pickle_path, "word2id.pkl") | |||||
label_vocab = load_pickle(pickle_path, "label2id.pkl") | |||||
test_args["vocab_size"] = len(word_vocab) | |||||
test_args["num_classes"] = len(label_vocab) | |||||
print("vocabularies loaded") | |||||
# Define the same model | |||||
model = SeqLabeling(test_args) | |||||
print("model defined") | |||||
# Dump trained parameters into the model | |||||
ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name)) | |||||
print("model loaded!") | |||||
# Data Loader | |||||
infer_data = SeqLabelDataSet(load_func=BaseLoader.load) | |||||
infer_data.load(data_infer_path, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}, infer=True) | |||||
print("data set prepared") | |||||
# Inference interface | |||||
infer = SeqLabelInfer(pickle_path) | |||||
results = infer.predict(model, infer_data) | |||||
for res in results: | |||||
print(res) | |||||
print("Inference finished!") | |||||
def train_and_test():
    """Train a sequence-labeling model, save it, reload it and test it.

    Reads the module-level ``config_dir``, ``data_path``, ``pickle_path`` and
    ``model_name`` settings. The word and label vocabularies are pickled into
    ``pickle_path`` so that inference can rebuild the same model later.
    """
    # Trainer and model hyper-parameters come from two config sections.
    trainer_args = ConfigSection()
    model_args = ConfigSection()
    ConfigLoader().load_config(config_dir, {
        "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})

    # Load the data and hold part of it out as a dev set.
    data_set = SeqLabelDataSet()
    data_set.load(data_path)
    train_set, dev_set = data_set.split(0.3, shuffle=True)
    model_args["vocab_size"] = len(data_set.word_vocab)
    model_args["num_classes"] = len(data_set.label_vocab)

    save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl")
    save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl")

    # Model; training is driven through model.fit().
    model = SeqLabeling(model_args)
    model.fit(train_set, dev_set,
              epochs=trainer_args["epochs"],
              batch_size=trainer_args["batch_size"],
              validate=False,
              use_cuda=trainer_args["use_cuda"],
              pickle_path=pickle_path,
              save_best_dev=trainer_args["save_best_dev"],
              model_name=model_name,
              optimizer=Optimizer("SGD", lr=0.01, momentum=0.9))
    print("Training finished!")

    # Saver
    saver = ModelSaver(os.path.join(pickle_path, model_name))
    saver.save_pytorch(model)
    print("Model saved!")

    # Drop the trained instance so the test below runs on reloaded weights.
    del model

    change_field_is_target(dev_set, "truth", True)

    # Define the same model and restore the saved parameters into it.
    model = SeqLabeling(model_args)
    ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))
    print("model loaded!")

    # Load test configuration.
    # NOTE(review): tester_args is loaded but not consulted below; the tester
    # is configured with literal values.
    tester_args = ConfigSection()
    ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args})

    # Tester
    tester = SeqLabelTester(batch_size=4,
                            use_cuda=False,
                            pickle_path=pickle_path,
                            model_name="seq_label_in_test.pkl",
                            evaluator=SeqLabelEvaluator()
                            )

    # Start testing with validation data
    tester.test(model, dev_set)
    print("model tested!")
# Script entry point: train and test first, then run inference on the saved model.
if __name__ == "__main__":
    for step in (train_and_test, infer):
        step()
@@ -1,25 +0,0 @@ | |||||
import unittest | |||||
import numpy as np | |||||
import torch | |||||
from fastNLP.models.char_language_model import CharLM | |||||
class TestCharLM(unittest.TestCase):
    """Shape check for the character-level language model."""

    def test_case_1(self):
        """Feed random character ids through CharLM and check the output shape."""
        char_emb_dim, word_emb_dim = 50, 50
        vocab_size, num_char = 1000, 24
        max_word_len = 21
        num_seq, seq_len = 64, 32
        # Each word carries two positions beyond max_word_len.
        input_shape = (num_seq, seq_len, max_word_len + 2)

        model = CharLM(char_emb_dim, word_emb_dim, vocab_size, num_char)

        char_ids = np.random.randint(0, num_char, size=input_shape)
        x = torch.from_numpy(char_ids)
        self.assertEqual(tuple(x.shape), input_shape)

        y = model(x)
        self.assertEqual(tuple(y.shape), (num_seq * seq_len, vocab_size))
@@ -1,111 +0,0 @@ | |||||
import os | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
from fastNLP.core.predictor import Predictor | |||||
from fastNLP.core.tester import Tester | |||||
from fastNLP.core.trainer import Trainer | |||||
from fastNLP.core.utils import save_pickle, load_pickle | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.io.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.io.dataset_loader import TokenizeDataSetLoader, RawDataSetLoader | |||||
from fastNLP.io.model_loader import ModelLoader | |||||
from fastNLP.io.model_saver import ModelSaver | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | |||||
# Input fixtures, relative to the repository root.
config_path = "./test/data_for_tests/config"
cws_data_path = "./test/data_for_tests/cws_pku_utf_8"
data_infer_path = "./test/data_for_tests/people_infer.txt"
data_name = "pku_training.utf8"

# Scratch directory for pickled vocabularies and the saved model.
pickle_path = "./save/"
def infer():
    """Reload the saved model and print its predictions for the inference data.

    Reads the module-level ``config_path``, ``pickle_path`` and
    ``data_infer_path`` settings.
    """
    # Inference shares the "POS_infer" configuration section.
    infer_args = ConfigSection()
    ConfigLoader().load_config(config_path, {"POS_infer": infer_args})

    # Model dimensions come from the vocabularies pickled during training.
    word_vocab = load_pickle(pickle_path, "word2id.pkl")
    infer_args["vocab_size"] = len(word_vocab)
    infer_args["num_classes"] = len(load_pickle(pickle_path, "label2id.pkl"))

    # Build the same architecture and restore the trained parameters.
    model = SeqLabeling(infer_args)
    ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
    print("model loaded!")

    # Load the raw data and index it with the training word vocabulary.
    infer_data = RawDataSetLoader().load(data_infer_path)
    infer_data.index_field("word_seq", word_vocab)
    infer_data.set_origin_len("word_seq")

    # Predict and print all results at once.
    predictor = Predictor(pickle_path)
    print(predictor.predict(model, infer_data))
def train_test():
    """Train on the CWS data, save and reload the model, then test it.

    Reads the module-level ``config_path``, ``cws_data_path`` and
    ``pickle_path`` settings. The model is tested on the training data itself.
    """
    # The "POS_infer" section supplies both trainer and model settings.
    train_args = ConfigSection()
    ConfigLoader().load_config(config_path, {"POS_infer": train_args})

    # Build the vocabularies from the data set, then index both fields.
    data_train = TokenizeDataSetLoader().load(cws_data_path)
    word_vocab, label_vocab = Vocabulary(), Vocabulary()
    data_train.update_vocab(word_seq=word_vocab, label_seq=label_vocab)

    word_indexed = data_train.index_field("word_seq", word_vocab)
    word_indexed.index_field("label_seq", label_vocab)
    data_train.set_origin_len("word_seq")

    # The label field is renamed to "truth" and is not a target while training.
    renamed = data_train.rename_field("label_seq", "truth")
    renamed.set_target(truth=False)

    train_args["vocab_size"] = len(word_vocab)
    train_args["num_classes"] = len(label_vocab)

    # Persist the vocabularies for infer().
    save_pickle(word_vocab, pickle_path, "word2id.pkl")
    save_pickle(label_vocab, pickle_path, "label2id.pkl")

    # Train
    trainer = Trainer(**train_args.data)
    model = SeqLabeling(train_args)
    trainer.train(model, data_train)

    # Save, then drop the in-memory objects so the test uses reloaded weights.
    ModelSaver("./save/saved_model.pkl").save_pytorch(model)
    del model, trainer

    model = SeqLabeling(train_args)
    ModelLoader.load_pytorch(model, "./save/saved_model.pkl")

    # The tester is configured from the same section plus an evaluator.
    test_args = ConfigSection()
    ConfigLoader().load_config(config_path, {"POS_infer": test_args})
    test_args["evaluator"] = SeqLabelEvaluator()
    tester = Tester(**test_args.data)

    # "truth" becomes a target again for testing.
    data_train.set_target(truth=True)
    tester.test(model, data_train)
def test():
    """Run training followed by inference inside a throw-away ``save`` directory.

    The directory is created up front and removed afterwards, even when
    training or inference raises, so a failed run leaves nothing behind.
    """
    # Local import: only this cleanup needs shutil.
    import shutil

    os.makedirs("save", exist_ok=True)
    try:
        train_test()
        infer()
    finally:
        # Portable replacement for shelling out to `rm -rf save`.
        shutil.rmtree("save", ignore_errors=True)
# Script entry point: train and test first, then run inference on the saved model.
if __name__ == "__main__":
    for step in (train_test, infer):
        step()
@@ -1,90 +0,0 @@ | |||||
import os | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
from fastNLP.core.optimizer import Optimizer | |||||
from fastNLP.core.tester import Tester | |||||
from fastNLP.core.trainer import Trainer | |||||
from fastNLP.core.utils import save_pickle | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.io.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.io.dataset_loader import TokenizeDataSetLoader | |||||
from fastNLP.io.model_loader import ModelLoader | |||||
from fastNLP.io.model_saver import ModelSaver | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | |||||
# Input fixtures, relative to the directory the test is run from.
config_dir = "../data_for_tests/config"
data_path = "../data_for_tests/people.txt"
data_infer_path = "../data_for_tests/people_infer.txt"

# Output location and file name for the trained model and vocabularies.
pickle_path = "./seq_label/"
model_name = "seq_label_model.pkl"
def test_training():
    """Train a sequence labeler, save and reload it, then test on the dev split.

    Reads the module-level ``config_dir``, ``data_path``, ``pickle_path`` and
    ``model_name`` settings.
    """
    # Trainer and model hyper-parameters come from two config sections.
    trainer_args, model_args = ConfigSection(), ConfigSection()
    sections = {"test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}
    ConfigLoader().load_config(config_dir, sections)

    # Build the vocabularies from the data, then index both fields.
    data_set = TokenizeDataSetLoader().load(data_path)
    word_vocab, label_vocab = Vocabulary(), Vocabulary()
    data_set.update_vocab(word_seq=word_vocab, label_seq=label_vocab)

    word_indexed = data_set.index_field("word_seq", word_vocab)
    word_indexed.index_field("label_seq", label_vocab)
    data_set.set_origin_len("word_seq")

    # The label field is renamed to "truth" and is not a target while training.
    renamed = data_set.rename_field("label_seq", "truth")
    renamed.set_target(truth=False)

    data_train, data_dev = data_set.split(0.3, shuffle=True)
    model_args["vocab_size"] = len(word_vocab)
    model_args["num_classes"] = len(label_vocab)

    save_pickle(word_vocab, pickle_path, "word2id.pkl")
    save_pickle(label_vocab, pickle_path, "label2id.pkl")

    # Trainer settings: some from the config, the rest fixed for this test.
    trainer_kwargs = {
        "epochs": trainer_args["epochs"],
        "batch_size": trainer_args["batch_size"],
        "validate": False,
        "use_cuda": False,
        "pickle_path": pickle_path,
        "save_best_dev": trainer_args["save_best_dev"],
        "model_name": model_name,
        "optimizer": Optimizer("SGD", lr=0.01, momentum=0.9),
    }
    trainer = Trainer(**trainer_kwargs)

    # Train
    model = SeqLabeling(model_args)
    trainer.train(model, data_train, data_dev)

    # Save, then drop the in-memory objects so the test uses reloaded weights.
    model_file = os.path.join(pickle_path, model_name)
    ModelSaver(model_file).save_pytorch(model)
    del model, trainer

    model = SeqLabeling(model_args)
    ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))

    # NOTE(review): tester_args is loaded but not consulted below; the tester
    # is configured with literal values.
    tester_args = ConfigSection()
    ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args})

    tester = Tester(
        batch_size=4,
        use_cuda=False,
        pickle_path=pickle_path,
        model_name="seq_label_in_test.pkl",
        evaluator=SeqLabelEvaluator(),
    )

    # "truth" becomes a target again before testing on the dev split.
    data_dev.set_target(truth=True)
    tester.test(model, data_dev)
# Script entry point: run the train / save / reload / test round trip.
if __name__ == "__main__":
    test_training()
@@ -1,107 +0,0 @@ | |||||
# Python: 3.5 | |||||
# encoding: utf-8 | |||||
import argparse | |||||
import os | |||||
import sys | |||||
sys.path.append("..") | |||||
from fastNLP.core.predictor import ClassificationInfer | |||||
from fastNLP.core.trainer import ClassificationTrainer | |||||
from fastNLP.io.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.io.dataset_loader import ClassDataSetLoader | |||||
from fastNLP.io.model_loader import ModelLoader | |||||
from fastNLP.models.cnn_text_classification import CNNText | |||||
from fastNLP.io.model_saver import ModelSaver | |||||
from fastNLP.core.optimizer import Optimizer | |||||
from fastNLP.core.loss import Loss | |||||
from fastNLP.core.dataset import TextClassifyDataSet | |||||
from fastNLP.core.utils import save_pickle, load_pickle | |||||
# Command-line options; the defaults point at the bundled test fixtures.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-s", "--save",
    type=str,
    default="./test_classification/",
    help="path to save pickle files",
)
parser.add_argument(
    "-t", "--train",
    type=str,
    default="../data_for_tests/text_classify.txt",
    help="path to the training data",
)
parser.add_argument(
    "-c", "--config",
    type=str,
    default="../data_for_tests/config",
    help="path to the config file",
)
parser.add_argument(
    "-m", "--model_name",
    type=str,
    default="classify_model.pkl",
    help="the name of the model",
)
args = parser.parse_args()

# Module-level settings read by infer() and train().
save_dir, train_data_dir = args.save, args.train
model_name, config_dir = args.model_name, args.config
def infer():
    """Reload the saved text classifier and print its predictions.

    Reads the module-level ``save_dir``, ``train_data_dir``, ``config_dir`` and
    ``model_name`` settings. Predictions are made on the training data file.
    """
    # Vocabularies pickled by train() define the model dimensions.
    print("Loading data...")
    vocabs = {
        "word_vocab": load_pickle(save_dir, "word2id.pkl"),
        "label_vocab": load_pickle(save_dir, "label2id.pkl"),
    }
    vocab_size = len(vocabs["word_vocab"])
    num_classes = len(vocabs["label_vocab"])
    print("vocabulary size:", vocab_size)
    print("number of classes:", num_classes)

    infer_data = TextClassifyDataSet(load_func=ClassDataSetLoader.load)
    infer_data.load(train_data_dir, vocabs=vocabs)

    # The sizes are set first; the config section is loaded afterwards.
    model_args = ConfigSection()
    model_args["vocab_size"] = vocab_size
    model_args["num_classes"] = num_classes
    ConfigLoader.load_config(config_dir, {"text_class_model": model_args})

    # Build the same architecture and restore the trained parameters.
    print("Building model...")
    cnn = CNNText(model_args)
    ModelLoader.load_pytorch(cnn, os.path.join(save_dir, model_name))
    print("model loaded!")

    predictor = ClassificationInfer(pickle_path=save_dir)
    print(predictor.predict(cnn, infer_data))
def train():
    """Train the CNN text classifier and save it together with its vocabularies.

    Reads the module-level ``config_dir``, ``train_data_dir``, ``save_dir`` and
    ``model_name`` settings.
    """
    train_args = ConfigSection()
    model_args = ConfigSection()
    ConfigLoader.load_config(config_dir, {"text_class": train_args})

    # Load the data set; its vocabularies are built while loading.
    print("Loading data...")
    data = TextClassifyDataSet(load_func=ClassDataSetLoader.load)
    data.load(train_data_dir)

    print("vocabulary size:", len(data.word_vocab))
    print("number of classes:", len(data.label_vocab))

    # Persist the vocabularies so infer() can rebuild the same model.
    save_pickle(data.word_vocab, save_dir, "word2id.pkl")
    save_pickle(data.label_vocab, save_dir, "label2id.pkl")

    # NOTE(review): model_args only receives the two sizes here, while infer()
    # also loads the "text_class_model" section; confirm the two stay in sync.
    model_args["num_classes"] = len(data.label_vocab)
    model_args["vocab_size"] = len(data.word_vocab)

    print("Building model...")
    model = CNNText(model_args)

    # Trainer settings from the config, plus a fixed loss and optimizer.
    print("Training...")
    trainer_kwargs = {
        "epochs": train_args["epochs"],
        "batch_size": train_args["batch_size"],
        "validate": train_args["validate"],
        "use_cuda": train_args["use_cuda"],
        "pickle_path": save_dir,
        "save_best_dev": train_args["save_best_dev"],
        "model_name": model_name,
        "loss": Loss("cross_entropy"),
        "optimizer": Optimizer("SGD", lr=0.001, momentum=0.9),
    }
    trainer = ClassificationTrainer(**trainer_kwargs)
    trainer.train(model, data)
    print("Training finished!")

    ModelSaver(os.path.join(save_dir, model_name)).save_pytorch(model)
    print("Model saved!")
# Script entry point: train and save the classifier, then run inference with it.
if __name__ == "__main__":
    for step in (train, infer):
        step()
@@ -14,7 +14,7 @@ class TestGroupNorm(unittest.TestCase): | |||||
class TestLayerNormalization(unittest.TestCase): | class TestLayerNormalization(unittest.TestCase): | ||||
def test_case_1(self): | def test_case_1(self): | ||||
ln = LayerNormalization(d_hid=5, eps=2e-3) | |||||
ln = LayerNormalization(layer_size=5, eps=2e-3) | |||||
x = torch.randn((20, 50, 5)) | x = torch.randn((20, 50, 5)) | ||||
y = ln(x) | y = ln(x) | ||||